Add support to constant dense vector/tensor attribute.

The syntax of dense vecor/tensor attribute value is

`dense<` (tensor-type | vector-type)`,` attribute-list`>`

and

attribute-list ::= `[` attribute-list (`, ` attribute-list)* `]`.

The construction of the dense vector/tensor attribute takes a vector/tensor
type and a character array as arguments. The size of the input array should be
larger than the size specified by the type argument. It also assumes the
elements of the vector or tensor have been trunked to the data type sizes in
the input character array, so it extends the trunked data to 64 bits when it is
retrieved.

PiperOrigin-RevId: 217762811
This commit is contained in:
Feng Liu 2018-10-18 13:54:44 -07:00 committed by jpienaar
parent 18e666702c
commit b5b90e5465
12 changed files with 844 additions and 134 deletions

View File

@ -46,14 +46,14 @@ public:
Function,
SplatElements,
DenseIntElements,
DenseFPElements,
FIRST_ELEMENTS_ATTR = SplatElements,
LAST_ELEMENTS_ATTR = SplatElements,
LAST_ELEMENTS_ATTR = DenseFPElements,
};
/// Return the classification for this attribute.
Kind getKind() const {
return kind;
}
Kind getKind() const { return kind; }
/// Return true if this field is, or contains, a function attribute.
bool isOrContainsFunction() const { return isOrContainsFunctionCache; }
@ -74,8 +74,8 @@ private:
/// This field is true if this is, or contains, a function attribute.
bool isOrContainsFunctionCache : 1;
Attribute(const Attribute&) = delete;
void operator=(const Attribute&) = delete;
Attribute(const Attribute &) = delete;
void operator=(const Attribute &) = delete;
};
inline raw_ostream &operator<<(raw_ostream &os, const Attribute &attr) {
@ -87,14 +87,13 @@ class BoolAttr : public Attribute {
public:
static BoolAttr *get(bool value, MLIRContext *context);
bool getValue() const {
return value;
}
bool getValue() const { return value; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::Bool;
}
private:
BoolAttr(bool value)
: Attribute(Kind::Bool, /*isOrContainsFunction=*/false), value(value) {}
@ -106,14 +105,13 @@ class IntegerAttr : public Attribute {
public:
static IntegerAttr *get(int64_t value, MLIRContext *context);
int64_t getValue() const {
return value;
}
int64_t getValue() const { return value; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::Integer;
}
private:
IntegerAttr(int64_t value)
: Attribute(Kind::Integer, /*isOrContainsFunction=*/false), value(value) {
@ -130,14 +128,13 @@ public:
// correctness, otherwise constant folding will be done with host math. This
// is completely incorrect for BF16 and other datatypes, and subtly wrong
// for float32.
double getValue() const {
return value;
}
double getValue() const { return value; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::Float;
}
private:
FloatAttr(double value)
: Attribute(Kind::Float, /*isOrContainsFunction=*/false), value(value) {}
@ -149,14 +146,13 @@ class StringAttr : public Attribute {
public:
static StringAttr *get(StringRef bytes, MLIRContext *context);
StringRef getValue() const {
return value;
}
StringRef getValue() const { return value; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::String;
}
private:
StringAttr(StringRef value)
: Attribute(Kind::String, /*isOrContainsFunction=*/false), value(value) {}
@ -168,21 +164,20 @@ private:
/// type homogenous given that attributes don't, in general, carry types.
class ArrayAttr : public Attribute {
public:
static ArrayAttr *get(ArrayRef<Attribute*> value, MLIRContext *context);
static ArrayAttr *get(ArrayRef<Attribute *> value, MLIRContext *context);
ArrayRef<Attribute*> getValue() const {
return value;
}
ArrayRef<Attribute *> getValue() const { return value; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::Array;
}
private:
ArrayAttr(ArrayRef<Attribute *> value, bool isOrContainsFunction)
: Attribute(Kind::Array, isOrContainsFunction), value(value) {}
~ArrayAttr() = delete;
ArrayRef<Attribute*> value;
ArrayRef<Attribute *> value;
};
class AffineMapAttr : public Attribute {
@ -289,6 +284,98 @@ private:
: ElementsAttr(Kind::SplatElements, type), elt(elt) {}
Attribute *elt;
};
/// An attribute represents a reference to a dense vector or tensor object.
///
/// This class is designed to store elements with any bit widths equal or less
/// than 64.
class DenseElementsAttr : public ElementsAttr {
public:
/// It assumes the elements in the input array have been truncated to the bits
/// width specified by the element type (note all float type are 64 bits).
/// When the value is retrieved, the bits are read from the storage and extend
/// to 64 bits if necessary.
static DenseElementsAttr *get(VectorOrTensorType *type, ArrayRef<char> data);
// TODO: Read the data from the attribute list and compress them
// to a character array. Then call the above method to construct the
// attribute.
static DenseElementsAttr *get(VectorOrTensorType *type,
ArrayRef<Attribute *> values);
void getValues(SmallVectorImpl<Attribute *> &values) const;
ArrayRef<char> getRawData() const { return data; }
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::DenseIntElements ||
attr->getKind() == Kind::DenseFPElements;
}
protected:
DenseElementsAttr(Kind kind, VectorOrTensorType *type, ArrayRef<char> data)
: ElementsAttr(kind, type), data(data) {}
private:
ArrayRef<char> data;
};
/// An attribute represents a reference to a dense integer vector or tensor
/// object.
class DenseIntElementsAttr : public DenseElementsAttr {
public:
DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef<char> data,
size_t bitsWidth)
: DenseElementsAttr(Kind::DenseIntElements, type, data),
bitsWidth(bitsWidth) {}
// TODO: returns APInts instead of IntegerAttr.
void getValues(SmallVectorImpl<Attribute *> &values) const;
APInt getValue(ArrayRef<int> indices) const;
/// Writes the lowest `bitWidth` bits of `value` to the bit position `bitPos`
/// in array `rawData`.
static void writeBits(char *rawData, size_t bitPos, size_t bitWidth,
uint64_t value);
/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
/// `rawData` and return them as the lowest bits of an uint64 integer.
static uint64_t readBits(const char *rawData, size_t bitPos,
size_t bitsWidth);
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::DenseIntElements;
}
private:
~DenseIntElementsAttr() = delete;
size_t bitsWidth;
};
/// An attribute represents a reference to a dense float vector or tensor
/// object. Each element is stored as a double.
class DenseFPElementsAttr : public DenseElementsAttr {
public:
DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef<char> data)
: DenseElementsAttr(Kind::DenseFPElements, type, data) {}
// TODO: returns APFPs instead of FloatAttr.
void getValues(SmallVectorImpl<Attribute *> &values) const;
APFloat getValue(ArrayRef<int> indices) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::DenseFPElements;
}
private:
~DenseFPElementsAttr() = delete;
};
} // end namespace mlir.
#endif

View File

@ -100,6 +100,8 @@ public:
TypeAttr *getTypeAttr(Type *type);
FunctionAttr *getFunctionAttr(const Function *value);
ElementsAttr *getSplatElementsAttr(VectorOrTensorType *type, Attribute *elt);
ElementsAttr *getDenseElementsAttr(VectorOrTensorType *type,
ArrayRef<char> data);
// Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position);

View File

@ -92,6 +92,10 @@ public:
/// Return true if this is an integer type with the specified width.
bool isInteger(unsigned width) const;
/// Return the bitwidth of this type. For vector or tensor types, returns the
/// element type's bitwidth.
unsigned getBitWidth() const;
// Convenience factories.
static IntegerType *getInteger(unsigned width, MLIRContext *ctx);
static FloatType *getBF16(MLIRContext *ctx);
@ -293,6 +297,10 @@ class VectorOrTensorType : public Type {
public:
Type *getElementType() const { return elementType; }
/// If this is ranked tensor or vector type, return the number of elements. If
/// it is an unranked tensor or vector, abort.
unsigned getNumElements() const;
/// If this is ranked tensor or vector type, return the rank. If it is an
/// unranked tensor, return -1.
int getRank() const;
@ -466,7 +474,6 @@ static bool isValidTensorElementType(Type *type) {
return isa<FloatType>(type) || isa<VectorType>(type) ||
isa<IntegerType>(type) || isa<OtherType>(type);
}
} // end namespace mlir
#endif // MLIR_IR_TYPES_H

View File

@ -41,8 +41,7 @@ namespace mlir {
template <typename ForwardIterator, typename UnaryFunctor,
typename NullaryFunctor>
inline void interleave(ForwardIterator begin, ForwardIterator end,
UnaryFunctor each_fn,
NullaryFunctor between_fn) {
UnaryFunctor each_fn, NullaryFunctor between_fn) {
if (begin == end)
return;
each_fn(*begin);
@ -59,6 +58,11 @@ inline void interleave(const Container &c, UnaryFunctor each_fn,
interleave(c.begin(), c.end(), each_fn, between_fn);
}
template <typename T, template <typename> class Container, typename raw_ostream>
inline void interleaveComma(const Container<T> &c, raw_ostream &os) {
interleave(c.begin(), c.end(), [&](T a) { os << a; }, [&]() { os << ", "; });
}
} // end namespace mlir
// Allow tuples to be usable as DenseMap keys.
@ -80,8 +84,7 @@ static inline unsigned llvm_combineHashValue(unsigned a, unsigned b) {
}
namespace llvm {
template<typename ...Ts>
struct DenseMapInfo<std::tuple<Ts...> > {
template <typename... Ts> struct DenseMapInfo<std::tuple<Ts...>> {
typedef std::tuple<Ts...> Tuple;
static inline Tuple getEmptyKey() {
@ -92,34 +95,34 @@ struct DenseMapInfo<std::tuple<Ts...> > {
return Tuple(DenseMapInfo<Ts>::getTombstoneKey()...);
}
template<unsigned I>
static unsigned getHashValueImpl(const Tuple& values, std::false_type) {
template <unsigned I>
static unsigned getHashValueImpl(const Tuple &values, std::false_type) {
typedef typename std::tuple_element<I, Tuple>::type EltType;
std::integral_constant<bool, I+1 == sizeof...(Ts)> atEnd;
std::integral_constant<bool, I + 1 == sizeof...(Ts)> atEnd;
return llvm_combineHashValue(
DenseMapInfo<EltType>::getHashValue(std::get<I>(values)),
getHashValueImpl<I+1>(values, atEnd));
DenseMapInfo<EltType>::getHashValue(std::get<I>(values)),
getHashValueImpl<I + 1>(values, atEnd));
}
template<unsigned I>
static unsigned getHashValueImpl(const Tuple& values, std::true_type) {
template <unsigned I>
static unsigned getHashValueImpl(const Tuple &values, std::true_type) {
return 0;
}
static unsigned getHashValue(const std::tuple<Ts...>& values) {
static unsigned getHashValue(const std::tuple<Ts...> &values) {
std::integral_constant<bool, 0 == sizeof...(Ts)> atEnd;
return getHashValueImpl<0>(values, atEnd);
}
template<unsigned I>
template <unsigned I>
static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::false_type) {
typedef typename std::tuple_element<I, Tuple>::type EltType;
std::integral_constant<bool, I+1 == sizeof...(Ts)> atEnd;
return DenseMapInfo<EltType>::isEqual(std::get<I>(lhs), std::get<I>(rhs))
&& isEqualImpl<I+1>(lhs, rhs, atEnd);
std::integral_constant<bool, I + 1 == sizeof...(Ts)> atEnd;
return DenseMapInfo<EltType>::isEqual(std::get<I>(lhs), std::get<I>(rhs)) &&
isEqualImpl<I + 1>(lhs, rhs, atEnd);
}
template<unsigned I>
template <unsigned I>
static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::true_type) {
return true;
}

View File

@ -296,6 +296,7 @@ protected:
void printAffineMapReference(AffineMap affineMap);
void printIntegerSetId(int integerSetId) const;
void printIntegerSetReference(IntegerSet integerSet);
void printDenseElementsAttr(const DenseElementsAttr *attr);
/// This enum is used to represent the binding stength of the enclosing
/// context that an AffineExprStorage is being printed in, so we can
@ -457,6 +458,16 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
}
break;
}
case Attribute::Kind::DenseIntElements:
case Attribute::Kind::DenseFPElements: {
auto *eltsAttr = cast<DenseElementsAttr>(attr);
os << "dense<";
printType(eltsAttr->getType());
os << ", ";
printDenseElementsAttr(eltsAttr);
os << '>';
break;
}
case Attribute::Kind::SplatElements: {
auto *elementsAttr = cast<SplatElementsAttr>(attr);
os << "splat<";
@ -469,6 +480,59 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
}
}
void ModulePrinter::printDenseElementsAttr(const DenseElementsAttr *attr) {
auto *type = attr->getType();
auto shape = type->getShape();
auto rank = type->getRank();
SmallVector<Attribute *, 16> elements;
attr->getValues(elements);
// Special case for degenerate tensors.
if (elements.empty()) {
for (int i = 0; i < rank; ++i)
os << '[';
for (int i = 0; i < rank; ++i)
os << ']';
return;
}
// We use a mixed-radix counter to iterate through the shape. When we bump a
// non-least-significant digit, we emit a close bracket. When we next emit an
// element we re-open all closed brackets.
// The mixed-radix counter, with radices in 'shape'.
SmallVector<unsigned, 4> counter(rank, 0);
// The number of brackets that have been opened and not closed.
unsigned openBrackets = 0;
auto bumpCounter = [&]() {
// Bump the least significant digit.
++counter[rank - 1];
// Iterate backwards bubbling back the increment.
for (unsigned i = rank - 1; i > 0; --i)
if (counter[i] >= shape[i]) {
// Index 'i' is rolled over. Bump (i-1) and close a bracket.
counter[i] = 0;
++counter[i - 1];
--openBrackets;
os << ']';
}
};
for (unsigned idx = 0, e = elements.size(); idx != e; ++idx) {
if (idx != 0)
os << ", ";
while (openBrackets++ < rank)
os << '[';
openBrackets = rank;
printAttribute(elements[idx]);
bumpCounter();
}
while (openBrackets-- > 0)
os << ']';
}
void ModulePrinter::printType(const Type *type) {
switch (type->getKind()) {
case Type::Kind::Index:

View File

@ -149,6 +149,11 @@ ElementsAttr *Builder::getSplatElementsAttr(VectorOrTensorType *type,
return SplatElementsAttr::get(type, elt);
}
ElementsAttr *Builder::getDenseElementsAttr(VectorOrTensorType *type,
ArrayRef<char> data) {
return DenseElementsAttr::get(type, data);
}
//===----------------------------------------------------------------------===//
// Affine Expressions, Affine Maps, and Integet Sets.
//===----------------------------------------------------------------------===//

View File

@ -32,6 +32,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Allocator.h"
@ -180,6 +181,22 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
}
};
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttr *> {
using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
using DenseMapInfo<DenseElementsAttr *>::getHashValue;
using DenseMapInfo<DenseElementsAttr *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
key.first, hash_combine_range(key.second.begin(), key.second.end()));
}
static bool isEqual(const KeyTy &lhs, const DenseElementsAttr *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == std::make_pair(rhs->getType(), rhs->getRawData());
}
};
} // end anonymous namespace.
namespace mlir {
@ -277,6 +294,9 @@ public:
DenseMap<const Function *, FunctionAttr *> functionAttrs;
DenseMap<std::pair<VectorOrTensorType *, Attribute *>, SplatElementsAttr *>
splatElementsAttrs;
using DenseElementsAttrSet =
DenseSet<DenseElementsAttr *, DenseElementsAttrInfo>;
DenseElementsAttrSet denseElementsAttrs;
public:
MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {}
@ -798,6 +818,139 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
return *existing.first = result;
}
DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
ArrayRef<char> data) {
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
assert((bitsRequired <= data.size() * 8L) &&
"Input data bit size should be larger than that type requires");
auto &impl = type->getContext()->getImpl();
// Look to see if this constant is already defined.
DenseElementsAttrInfo::KeyTy key({type, data});
auto existing = impl.denseElementsAttrs.insert_as(nullptr, key);
// If we already have it, return that value.
if (!existing.second)
return *existing.first;
// Otherwise, allocate a new one, unique it and return it.
auto *eltType = type->getElementType();
switch (eltType->getKind()) {
case Type::Kind::BF16:
case Type::Kind::F16:
case Type::Kind::F32:
case Type::Kind::F64: {
auto *result = impl.allocator.Allocate<DenseFPElementsAttr>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy);
new (result) DenseFPElementsAttr(type, {copy, data.size()});
return *existing.first = result;
}
case Type::Kind::Integer: {
auto width = cast<IntegerType>(eltType)->getWidth();
auto *result = impl.allocator.Allocate<DenseIntElementsAttr>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy);
new (result) DenseIntElementsAttr(type, {copy, data.size()}, width);
return *existing.first = result;
}
default:
llvm_unreachable("unexpected element type");
}
}
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
/// starting from `rawData`.
void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
uint64_t value) {
// Read the destination bytes which will be written to.
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitWidth;
auto start = data + bitPos / 8;
auto end = data + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
// Clean up the invalid bits in the destination bytes.
dst &= ~(-1UL << (bitPos % 8));
// Get the valid bits of the source value, shift them to right position,
// then add them to the destination bytes.
value <<= bitPos % 8;
dst |= value;
// Write the destination bytes back.
ArrayRef<char> range({dstData, (size_t)(end - start)});
std::copy(range.begin(), range.end(), start);
}
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
/// and put them in the lowest bits.
uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
size_t bitsWidth) {
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitsWidth;
auto start = rawData + bitPos / 8;
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
dst >>= bitPos % 8;
dst &= ~(-1UL << bitsWidth);
return dst;
}
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute *> &values) const {
switch (getKind()) {
case Attribute::Kind::DenseIntElements:
cast<DenseIntElementsAttr>(this)->getValues(values);
return;
case Attribute::Kind::DenseFPElements:
cast<DenseFPElementsAttr>(this)->getValues(values);
return;
default:
llvm_unreachable("unexpected element type");
}
}
void DenseIntElementsAttr::getValues(
SmallVectorImpl<Attribute *> &values) const {
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
values.reserve(elementNum);
if (bitsWidth == 64) {
ArrayRef<int64_t> vs(
{reinterpret_cast<const int64_t *>(getRawData().data()),
getRawData().size() / 8});
for (auto value : vs) {
auto *attr = IntegerAttr::get(value, context);
values.push_back(attr);
}
} else {
const auto *rawData = getRawData().data();
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
uint64_t bits = readBits(rawData, pos, bitsWidth);
APInt value(bitsWidth, bits, /*isSigned=*/true);
auto *attr = IntegerAttr::get(value.getSExtValue(), context);
values.push_back(attr);
}
}
}
void DenseFPElementsAttr::getValues(
SmallVectorImpl<Attribute *> &values) const {
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
getRawData().size() / 8});
values.reserve(elementNum);
for (auto v : vs) {
auto *attr = FloatAttr::get(v, context);
values.push_back(attr);
}
}
ElementsAttr *SplatElementsAttr::get(VectorOrTensorType *type, Attribute *elt) {
auto &impl = type->getContext()->getImpl();

View File

@ -17,13 +17,35 @@
#include "mlir/IR/Types.h"
#include "mlir/IR/AffineMap.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
unsigned Type::getBitWidth() const {
switch (getKind()) {
// TODO: Currently the IR uses host double type to store all the float
// datatypes. This is completely incorrect for BF16 and other datatypes.
// We have to fix this once APFloat is used in the IR.
case Type::Kind::BF16:
case Type::Kind::F16:
case Type::Kind::F32:
case Type::Kind::F64:
return 64;
case Type::Kind::Integer:
return cast<IntegerType>(this)->getWidth();
case Type::Kind::Vector:
case Type::Kind::RankedTensor:
case Type::Kind::UnrankedTensor:
return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth();
// TODO: Handle more types.
default:
llvm_unreachable("unexpected type");
}
}
IntegerType::IntegerType(unsigned width, MLIRContext *context)
: Type(Kind::Integer, context), width(width) {
assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
: Type(Kind::Integer, context), width(width) {
assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
}
FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {}
@ -32,25 +54,39 @@ OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {}
FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
unsigned numResults, MLIRContext *context)
: Type(Kind::Function, context, numInputs),
numResults(numResults), inputsAndResults(inputsAndResults) {
}
: Type(Kind::Function, context, numInputs), numResults(numResults),
inputsAndResults(inputsAndResults) {}
VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
Type *elementType, unsigned subClassData)
: Type(kind, context, subClassData), elementType(elementType) {}
unsigned VectorOrTensorType::getNumElements() const {
switch (getKind()) {
case Kind::Vector:
case Kind::RankedTensor: {
auto shape = getShape();
unsigned num = 1;
for (auto dim : shape)
num *= dim;
return num;
}
default:
llvm_unreachable("not a VectorOrTensorType or not ranked");
}
}
/// If this is ranked tensor or vector type, return the rank. If it is an
/// unranked tensor, return -1.
int VectorOrTensorType::getRank() const {
switch (getKind()) {
default:
llvm_unreachable("not a VectorOrTensorType");
case Kind::Vector:
case Kind::RankedTensor:
return getShape().size();
case Kind::UnrankedTensor:
return -1;
default:
llvm_unreachable("not a VectorOrTensorType");
}
}
@ -60,7 +96,7 @@ int VectorOrTensorType::getDimSize(unsigned i) const {
case Kind::RankedTensor:
return getShape()[i];
default:
llvm_unreachable("not a VectorOrTensorType");
llvm_unreachable("not a VectorOrTensorType or not ranked");
}
}
@ -94,14 +130,13 @@ TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context)
: TensorType(Kind::RankedTensor, elementType, context),
shapeElements(shape.data()) {
: TensorType(Kind::RankedTensor, elementType, context),
shapeElements(shape.data()) {
setSubclassData(shape.size());
}
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
: TensorType(Kind::UnrankedTensor, elementType, context) {
}
: TensorType(Kind::UnrankedTensor, elementType, context) {}
MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,

View File

@ -35,6 +35,8 @@
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/PrettyStackTrace.h"
@ -205,6 +207,8 @@ public:
AffineMap parseAffineMapReference();
IntegerSet parseIntegerSetInline();
IntegerSet parseIntegerSetReference();
ElementsAttr *parseDenseElementsAttr(VectorOrTensorType *type);
VectorOrTensorType *parseVectorOrTensorType();
private:
// The Parser is subclassed and reinstantiated. Do not add additional
@ -624,6 +628,144 @@ ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
// Attribute parsing.
//===----------------------------------------------------------------------===//
namespace {
class TensorLiteralParser {
public:
TensorLiteralParser(Parser &p, Type *eltTy)
: p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy->getBitWidth()) {}
ParseResult parse() { return parseList(shape); }
ArrayRef<char> getValues() const {
return {reinterpret_cast<const char *>(storage.data()), storage.size() * 8};
}
ArrayRef<int> getShape() const { return shape; }
private:
/// Parse either a single element or a list of elements. Return the dimensions
/// of the parsed sub-tensor in dims.
ParseResult parseElementOrList(llvm::SmallVectorImpl<int> &dims);
/// Parse a list of either lists or elements, returning the dimensions of the
/// parsed sub-tensors in dims. For example:
/// parseList([1, 2, 3]) -> Success, [3]
/// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
/// parseList([[1, 2], 3]) -> Failure
/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
ParseResult parseList(llvm::SmallVectorImpl<int> &dims);
void addToStorage(uint64_t value) {
if (bitsWidth == 64)
storage.push_back(value);
if (currBitPos + bitsWidth > storage.size() * 64)
storage.push_back(0L);
auto *rawData = reinterpret_cast<char *>(storage.data());
DenseIntElementsAttr::writeBits(rawData, currBitPos, bitsWidth, value);
currBitPos += bitsWidth;
}
Parser &p;
Type *eltTy;
size_t currBitPos;
size_t bitsWidth;
SmallVector<int, 4> shape;
std::vector<uint64_t> storage;
};
} // namespace
/// Parse either a single element or a list of elements. Return the dimensions
/// of the parsed sub-tensor in dims.
ParseResult
TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
switch (p.getToken().getKind()) {
case Token::l_square:
return parseList(dims);
case Token::floatliteral:
case Token::integer:
case Token::minus: {
auto *result = p.parseAttribute();
if (!result)
return p.emitError("expected tensor element");
// check result matches the element type.
switch (eltTy->getKind()) {
case Type::Kind::BF16:
case Type::Kind::F16:
case Type::Kind::F32:
case Type::Kind::F64: {
if (!isa<FloatAttr>(result))
return p.emitError("expected tensor literal element has float type");
double value = cast<FloatAttr>(result)->getValue();
addToStorage(*(uint64_t *)(&value));
break;
}
case Type::Kind::Integer: {
if (!isa<IntegerAttr>(result))
return p.emitError("expected tensor literal element has integer type");
auto value = cast<IntegerAttr>(result)->getValue();
// If we couldn't successfully round trip the value, it means some bits
// are truncated and we should give up here.
llvm::APInt apint(bitsWidth, (uint64_t)value, /*isSigned=*/true);
if (apint.getSExtValue() != value)
return p.emitError("tensor literal element has more bits than that "
"specified in the type");
addToStorage((uint64_t)value);
break;
}
default:
return p.emitError("expected integer or float tensor element");
}
break;
}
default:
return p.emitError("expected '[' or scalar constant inside tensor literal");
}
return ParseSuccess;
}
/// Parse a list of either lists or elements, returning the dimensions of the
/// parsed sub-tensors in dims. For example:
/// parseList([1, 2, 3]) -> Success, [3]
/// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
/// parseList([[1, 2], 3]) -> Failure
/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
ParseResult TensorLiteralParser::parseList(llvm::SmallVectorImpl<int> &dims) {
p.consumeToken(Token::l_square);
auto checkDims = [&](const llvm::SmallVectorImpl<int> &prevDims,
const llvm::SmallVectorImpl<int> &newDims) {
if (prevDims == newDims)
return ParseSuccess;
return p.emitError("tensor literal is invalid; ranks are not consistent "
"between elements");
};
bool first = true;
llvm::SmallVector<int, 4> newDims;
unsigned size = 0;
auto parseCommaSeparatedList = [&]() {
llvm::SmallVector<int, 4> thisDims;
if (parseElementOrList(thisDims))
return ParseFailure;
++size;
if (!first)
return checkDims(newDims, thisDims);
newDims = thisDims;
first = false;
return ParseSuccess;
};
if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
return ParseFailure;
// Return the sublists' dimensions with 'size' prepended.
dims.clear();
dims.push_back(size);
dims.insert(dims.end(), newDims.begin(), newDims.end());
return ParseSuccess;
}
/// Given a parsed reference to a function name like @foo and a type that it
/// corresponds to, resolve it to a concrete function object (possibly
/// synthesizing a forward reference) or emit an error and return null on
@ -659,7 +801,7 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
/// | type
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
/// | function-id `:` function-type
/// | `splat<` (tensor-type | vector-type)`,`
/// | (`splat<` | `dense<`) (tensor-type | vector-type)`,`
/// attribute-value `>`
///
Attribute *Parser::parseAttribute() {
@ -757,24 +899,13 @@ Attribute *Parser::parseAttribute() {
case Token::kw_splat: {
consumeToken(Token::kw_splat);
if (parseToken(Token::less, "Expected '<' after 'elements'"))
if (parseToken(Token::less, "expected '<' after 'splat'"))
return nullptr;
auto *type = dyn_cast<VectorOrTensorType>(parseType());
if (!type) {
return (
emitError("expected elements literal has a tensor or vector type"),
nullptr);
}
if (parseToken(Token::comma, "Expected ','"))
auto *type = parseVectorOrTensorType();
if (!type)
return nullptr;
if (!type->hasStaticShape() || type->getRank() == -1) {
return (emitError("tensor literals must be ranked and have static shape"),
nullptr);
}
switch (getToken().getKind()) {
case Token::floatliteral:
case Token::integer:
@ -785,12 +916,32 @@ Attribute *Parser::parseAttribute() {
return builder.getSplatElementsAttr(type, scalar);
}
default:
return (
emitError("expected '[' or scalar constant inside tensor literal"),
nullptr);
return (emitError("expected scalar constant inside tensor literal"),
nullptr);
}
}
case Token::kw_dense: {
consumeToken(Token::kw_dense);
if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr;
auto *type = parseVectorOrTensorType();
if (!type)
return nullptr;
switch (getToken().getKind()) {
case Token::l_square: {
auto attr = parseDenseElementsAttr(type);
if (!attr)
return nullptr;
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return attr;
}
default:
return (emitError("expected '[' to start dense tensor literal"), nullptr);
}
}
default: {
if (Type *type = parseType())
return builder.getTypeAttr(type);
@ -799,6 +950,42 @@ Attribute *Parser::parseAttribute() {
}
}
ElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
auto *eltTy = type->getElementType();
TensorLiteralParser literalParser(*this, eltTy);
if (literalParser.parse())
return nullptr;
if (literalParser.getShape() != type->getShape()) {
std::string str;
llvm::raw_string_ostream s(str);
s << "inferred shape of elements literal ([";
interleaveComma(literalParser.getShape(), s);
s << "]) does not match type ([";
interleaveComma(type->getShape(), s);
s << "])";
return (emitError(s.str()), nullptr);
}
return builder.getDenseElementsAttr(type, literalParser.getValues());
}
VectorOrTensorType *Parser::parseVectorOrTensorType() {
auto *type = dyn_cast<VectorOrTensorType>(parseType());
if (!type) {
return (emitError("expected elements literal has a tensor or vector type"),
nullptr);
}
if (parseToken(Token::comma, "expected ','"))
return nullptr;
if (!type->hasStaticShape() || type->getRank() == -1) {
return (emitError("tensor literals must be ranked and have static shape"),
nullptr);
}
return type;
}
/// Attribute dictionary.
///
/// attribute-dict ::= `{` `}`
@ -848,8 +1035,8 @@ enum AffineLowPrecOp {
Sub
};
/// Higher precedence ops - all at the same precedence level. HNoOp is false in
/// the boolean sense.
/// Higher precedence ops - all at the same precedence level. HNoOp is false
/// in the boolean sense.
enum AffineHighPrecOp {
/// Null value.
HNoOp,
@ -957,8 +1144,8 @@ AffineExpr AffineParser::getBinaryAffineOpExpr(AffineLowPrecOp op,
}
}
/// Consume this token if it is a lower precedence affine op (there are only two
/// precedence levels).
/// Consume this token if it is a lower precedence affine op (there are only
/// two precedence levels).
AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
switch (getToken().getKind()) {
case Token::plus:
@ -1103,8 +1290,8 @@ AffineExpr AffineParser::parseIntegerExpr() {
// Eg: for an expression without parentheses (like i + j + k + l), each
// of the four identifiers is an operand. For i + j*k + l, j*k is not an
// operand expression, it's an op expression and will be parsed via
// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and -l
// are valid operands that will be parsed by this function.
// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
// -l are valid operands that will be parsed by this function.
AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
switch (getToken().getKind()) {
case Token::bare_identifier:
@ -1148,13 +1335,13 @@ AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
///
/// llhs: the affine expression appearing on the left of the one being parsed.
/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned if
/// llhs is non-null; otherwise lhs is returned. This is to deal with left
/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
/// if llhs is non-null; otherwise lhs is returned. This is to deal with left
/// associativity.
///
/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where (e2*e3)
/// will be parsed using parseAffineHighPrecOpExpr().
/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
AffineLowPrecOp llhsOp) {
AffineExpr lhs;
@ -1208,16 +1395,16 @@ AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
/// | bare-id
/// | integer-literal
///
/// Additional conditions are checked depending on the production. For eg., one
/// of the operands for `*` has to be either constant/symbolic; the second
/// Additional conditions are checked depending on the production. For eg.,
/// one of the operands for `*` has to be either constant/symbolic; the second
/// operand for floordiv, ceildiv, and mod has to be a positive integer.
AffineExpr AffineParser::parseAffineExpr() {
return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
}
/// Parse a dim or symbol from the lists appearing before the actual expressions
/// of the affine map. Update our state to store the dimensional/symbolic
/// identifier.
/// Parse a dim or symbol from the lists appearing before the actual
/// expressions of the affine map. Update our state to store the
/// dimensional/symbolic identifier.
ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) {
if (getToken().isNot(Token::bare_identifier))
return emitError("expected bare identifier");
@ -1288,9 +1475,9 @@ AffineMap AffineParser::parseAffineMapInline() {
return res;
};
// Parse a multi-dimensional affine expression (a comma-separated list of 1-d
// affine expressions); the list cannot be empty.
// Grammar: multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
// Parse a multi-dimensional affine expression (a comma-separated list of
// 1-d affine expressions); the list cannot be empty. Grammar:
// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false))
return AffineMap::Invalid();
@ -1357,8 +1544,8 @@ AffineMap Parser::parseAffineMapReference() {
//===----------------------------------------------------------------------===//
namespace {
/// This class contains parser state that is common across CFG and ML functions,
/// notably for dealing with operations and SSA values.
/// This class contains parser state that is common across CFG and ML
/// functions, notably for dealing with operations and SSA values.
class FunctionParser : public Parser {
public:
enum class Kind { CFGFunc, MLFunc };
@ -1371,15 +1558,15 @@ public:
/// This represents a use of an SSA value in the program. The first two
/// entries in the tuple are the name and result number of a reference. The
/// third is the location of the reference, which is used in case this ends up
/// being a use of an undefined value.
/// third is the location of the reference, which is used in case this ends
/// up being a use of an undefined value.
struct SSAUseInfo {
StringRef name; // Value name, e.g. %42 or %abc
unsigned number; // Number, specified with #12
SMLoc loc; // Location of first definition or use.
};
/// Given a reference to an SSA value and its type, return a reference. This
/// Given a reference to an SSA value and its type, return a reference. This
/// returns null on failure.
SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type *type);
@ -1442,8 +1629,9 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
// Forward references are always created as instructions, even in ML
// functions, because we just need something with a def/use chain.
//
// We create these placeholders as having an empty name, which we know cannot
// be created through normal user input, allowing us to distinguish them.
// We create these placeholders as having an empty name, which we know
// cannot be created through normal user input, allowing us to distinguish
// them.
auto name = OperationName("placeholder", getContext());
auto *inst = OperationInst::create(getEncodedSourceLocation(loc), name,
/*operands=*/{}, type,
@ -1512,9 +1700,9 @@ ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, SSAValue *value) {
"previously defined here");
}
// If it was a forward reference, update everything that used it to use the
// actual definition instead, delete the forward ref, and remove it from our
// set of forward references we track.
// If it was a forward reference, update everything that used it to use
// the actual definition instead, delete the forward ref, and remove it
// from our set of forward references we track.
existing->replaceAllUsesWith(value);
existing->getDefiningInst()->destroy();
forwardReferencePlaceholders.erase(existing);
@ -1528,7 +1716,8 @@ ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, SSAValue *value) {
/// After the function is finished parsing, this function checks to see if
/// there are any remaining issues.
ParseResult FunctionParser::finalizeFunction(Function *func, SMLoc loc) {
// Check for any forward references that are left. If we find any, error out.
// Check for any forward references that are left. If we find any, error
// out.
if (!forwardReferencePlaceholders.empty()) {
SmallVector<std::pair<const char *, SSAValue *>, 4> errors;
// Iteration over the map isn't deterministic, so sort by source location.
@ -1825,9 +2014,9 @@ public:
return !(result = parser.parseType());
}
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name. this
/// captures the location of the attribute in 'loc' if it is non-null.
/// Parse an arbitrary attribute and return it in result. This also adds
/// the attribute to the specified attribute list with the specified name.
/// this captures the location of the attribute in 'loc' if it is non-null.
bool parseAttribute(Attribute *&result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) override {
result = parser.parseAttribute();
@ -1997,7 +2186,8 @@ Operation *FunctionParser::parseCustomOperation(
consumeToken();
// If the custom op parser crashes, produce some indication to help debugging.
// If the custom op parser crashes, produce some indication to help
// debugging.
std::string opNameStr = opName.str();
llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'",
opNameStr.c_str());
@ -2176,7 +2366,8 @@ ParseResult CFGFunctionParser::parseBasicBlock() {
if (parseToken(Token::colon, "expected ':' after basic block name"))
return ParseFailure;
// Set the insertion point to the block we want to insert new operations into.
// Set the insertion point to the block we want to insert new operations
// into.
builder.setInsertionPoint(block);
auto createOpFunc = [&](const OperationState &result) -> Operation * {
@ -2218,7 +2409,8 @@ ParseResult CFGFunctionParser::parseBranchBlockAndUseList(
/// terminator-stmt ::= `br` bb-id branch-use-list?
/// branch-use-list ::= `(` ssa-use-list `)` ':' type-list-no-parens
/// terminator-stmt ::=
/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id
/// branch-use-list?
/// terminator-stmt ::= `return` ssa-use-and-type-list?
///
TerminatorInst *CFGFunctionParser::parseTerminator() {
@ -2471,9 +2663,9 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
// Loop bound.
///
/// lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound
/// upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound
/// shorthand-bound ::= ssa-id | `-`? integer-literal
/// lower-bound ::= `max`? affine-map dim-and-symbol-use-list |
/// shorthand-bound upper-bound ::= `min`? affine-map dim-and-symbol-use-list
/// | shorthand-bound shorthand-bound ::= ssa-id | `-`? integer-literal
///
ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
AffineMap &map, bool isLower) {
@ -2532,8 +2724,8 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
/// affine-constraint ::= affine-expr `>=` `0`
/// | affine-expr `==` `0`
///
/// isEq is set to true if the parsed constraint is an equality, false if it is
/// an inequality (greater than or equal).
/// isEq is set to true if the parsed constraint is an equality, false if it
/// is an inequality (greater than or equal).
///
AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
AffineExpr expr = parseAffineExpr();
@ -2568,9 +2760,11 @@ AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
/// Parse an integer set definition.
/// integer-set-inline
/// ::= dim-and-symbol-id-lists `:` affine-constraint-conjunction
/// ::= dim-and-symbol-id-lists `:`
/// affine-constraint-conjunction
/// affine-constraint-conjunction ::= /*empty*/
/// | affine-constraint (`,` affine-constraint)*
/// | affine-constraint (`,`
/// affine-constraint)*
///
IntegerSet AffineParser::parseIntegerSetInline() {
unsigned numDims = 0, numSymbols = 0;
@ -2859,11 +3053,12 @@ ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
}
/// Parse a function signature, starting with a name and including the parameter
/// list.
/// Parse a function signature, starting with a name and including the
/// parameter list.
///
/// argument-list ::= type (`,` type)* | /*empty*/ | ml-argument-list
/// function-signature ::= function-id `(` argument-list `)` (`->` type-list)?
/// function-signature ::= function-id `(` argument-list `)` (`->`
/// type-list)?
///
ParseResult
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
@ -2963,7 +3158,8 @@ ParseResult ModuleParser::parseCFGFunc() {
return ParseFailure;
}
// Okay, the CFG function signature was parsed correctly, create the function.
// Okay, the CFG function signature was parsed correctly, create the
// function.
auto *function =
new CFGFunction(getEncodedSourceLocation(loc), name, type, attrs);
getModule()->getFunctions().push_back(function);
@ -2979,7 +3175,8 @@ ParseResult ModuleParser::parseCFGFunc() {
/// ML function declarations.
///
/// ml-func ::= `mlfunc` ml-func-signature
/// (`attributes` attribute-dict)? `{` ml-stmt* ml-return-stmt `}`
/// (`attributes` attribute-dict)? `{` ml-stmt* ml-return-stmt
/// `}`
///
ParseResult ModuleParser::parseMLFunc() {
consumeToken(Token::kw_mlfunc);
@ -2997,7 +3194,8 @@ ParseResult ModuleParser::parseMLFunc() {
return ParseFailure;
}
// Okay, the ML function signature was parsed correctly, create the function.
// Okay, the ML function signature was parsed correctly, create the
// function.
auto *function =
MLFunction::create(getEncodedSourceLocation(loc), name, type, attrs);
getModule()->getFunctions().push_back(function);
@ -3019,9 +3217,9 @@ ParseResult ModuleParser::parseMLFunc() {
return parser.parseFunctionBody();
}
/// Given an attribute that could refer to a function attribute in the remapping
/// table, walk it and rewrite it to use the mapped function. If it doesn't
/// refer to anything in the table, then it is returned unmodified.
/// Given an attribute that could refer to a function attribute in the
/// remapping table, walk it and rewrite it to use the mapped function. If it
/// doesn't refer to anything in the table, then it is returned unmodified.
static Attribute *
remapFunctionAttrs(Attribute *input,
DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable,
@ -3097,8 +3295,8 @@ ParseResult ModuleParser::finalizeModule() {
if (remappingTable.empty())
return ParseSuccess;
// Otherwise, walk the entire module replacing uses of one attribute set with
// the correct ones.
// Otherwise, walk the entire module replacing uses of one attribute set
// with the correct ones.
for (auto &fn : *getModule()) {
if (auto *cfgFn = dyn_cast<CFGFunction>(&fn)) {
for (auto &bb : *cfgFn) {
@ -3147,8 +3345,8 @@ ParseResult ModuleParser::parseModule() {
return finalizeModule();
// If we got an error token, then the lexer already emitted an error, just
// stop. Someday we could introduce error recovery if there was demand for
// it.
// stop. Someday we could introduce error recovery if there was demand
// for it.
case Token::error:
return ParseFailure;
@ -3183,7 +3381,8 @@ ParseResult ModuleParser::parseModule() {
//===----------------------------------------------------------------------===//
/// This parses the file specified by the indicated SourceMgr and returns an
/// MLIR module if it was valid. If not, it emits diagnostics and returns null.
/// MLIR module if it was valid. If not, it emits diagnostics and returns
/// null.
Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
MLIRContext *context) {
@ -3195,16 +3394,16 @@ Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
return nullptr;
}
// Make sure the parse module has no other structural problems detected by the
// verifier.
// Make sure the parse module has no other structural problems detected by
// the verifier.
if (module->verify())
return nullptr;
return module.release();
}
/// This parses the program string to a MLIR module if it was valid. If not, it
/// emits diagnostics and returns null.
/// This parses the program string to a MLIR module if it was valid. If not,
/// it emits diagnostics and returns null.
Module *mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) {
auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr);
if (!memBuffer)

View File

@ -93,6 +93,7 @@ TOK_KEYWORD(br)
TOK_KEYWORD(ceildiv)
TOK_KEYWORD(cfgfunc)
TOK_KEYWORD(cond_br)
TOK_KEYWORD(dense)
TOK_KEYWORD(else)
TOK_KEYWORD(splat)
TOK_KEYWORD(extfunc)

View File

@ -628,3 +628,64 @@ mlfunc @calls(%arg0 : i32) {
// expected-error@+2 {{expected SSA operand}}
cfgfunc@n(){b(
// -----
cfgfunc @elementsattr_non_tensor_type() -> () {
bb0:
"foo"(){bar: dense<i32, [4]>} : () -> () // expected-error {{expected elements literal has a tensor or vector type}}
}
// -----
cfgfunc @elementsattr_non_ranked() -> () {
bb0:
"foo"(){bar: dense<tensor<?xi32>, [4]>} : () -> () // expected-error {{tensor literals must be ranked and have static shape}}
}
// -----
cfgfunc @elementsattr_shape_mismatch() -> () {
bb0:
"foo"(){bar: dense<tensor<5xi32>, [4]>} : () -> () // expected-error {{inferred shape of elements literal ([1]) does not match type ([5])}}
}
// -----
cfgfunc @elementsattr_invalid() -> () {
bb0:
"foo"(){bar: dense<tensor<2xi32>, [4, [5]]>} : () -> () // expected-error {{tensor literal is invalid; ranks are not consistent between elements}}
}
// -----
cfgfunc @elementsattr_badtoken() -> () {
bb0:
"foo"(){bar: dense<tensor<1xi32>, [tf_opaque]>} : () -> () // expected-error {{expected '[' or scalar constant inside tensor literal}}
}
// -----
cfgfunc @elementsattr_floattype1() -> () {
bb0:
"foo"(){bar: dense<tensor<1xi32>, [4.0]>} : () -> () // expected-error {{expected tensor literal element has integer type}}
}
// -----
cfgfunc @elementsattr_floattype2() -> () {
bb0:
"foo"(){bar: dense<tensor<1xf32>, [4]>} : () -> () // expected-error {{expected tensor literal element has float type}}
}
// -----
cfgfunc @elementsattr_toolarge1() -> () {
bb0:
"foo"(){bar: dense<tensor<1xi8>, [777]>} : () -> () // expected-error {{tensor literal element has more bits than that specified in the type}}
}
// -----
cfgfunc @elementsattr_toolarge2() -> () {
bb0:
"foo"(){bar: dense<tensor<1xi8>, [-777]>} : () -> () // expected-error {{tensor literal element has more bits than that specified in the type}}
}

View File

@ -485,8 +485,8 @@ mlfunc @mlfuncsimplemap(%arg0 : index, %arg1 : index) -> () {
return
}
// CHECK-LABEL: cfgfunc @tensorattr
cfgfunc @tensorattr() -> () {
// CHECK-LABEL: cfgfunc @splattensorattr
cfgfunc @splattensorattr() -> () {
bb0:
// CHECK: "splatIntTensor"() {bar: splat<tensor<2x1x4xi32>, 5>} : () -> ()
"splatIntTensor"(){bar: splat<tensor<2x1x4xi32>, 5>} : () -> ()
@ -498,3 +498,96 @@ bb0:
"splatFloatVector"(){bar: splat<vector<2x1x4xf16>, -5.0>} : () -> ()
return
}
// CHECK-LABEL: cfgfunc @densetensorattr
cfgfunc @densetensorattr() -> () {
bb0:
// NOTE: The {{\[\[}} syntax is because "[[" confuses FileCheck.
// CHECK: "fooi3"() {bar: dense<tensor<2x1x4xi3>, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> ()
"fooi3"(){bar: dense<tensor<2x1x4xi3>, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> ()
// CHECK: "fooi6"() {bar: dense<tensor<2x1x4xi6>, {{\[\[\[}}5, -6, 1, 2]], {{\[\[}}7, 8, 3, 4]]]>} : () -> ()
"fooi6"(){bar: dense<tensor<2x1x4xi6>, [[[5, -6, 1, 2]], [[7, 8, 3, 4]]]>} : () -> ()
// CHECK: "fooi8"() {bar: dense<tensor<1x1x1xi8>, {{\[\[\[}}5]]]>} : () -> ()
"fooi8"(){bar: dense<tensor<1x1x1xi8>, [[[5]]]>} : () -> ()
// CHECK: "fooi13"() {bar: dense<tensor<2x1x4xi13>, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> ()
"fooi13"(){bar: dense<tensor<2x1x4xi13>, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> ()
// CHECK: "fooi16"() {bar: dense<tensor<1x1x1xi16>, {{\[\[\[}}-5]]]>} : () -> ()
"fooi16"(){bar: dense<tensor<1x1x1xi16>, [[[-5]]]>} : () -> ()
// CHECK: "fooi23"() {bar: dense<tensor<2x1x4xi23>, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> ()
"fooi23"(){bar: dense<tensor<2x1x4xi23>, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> ()
// CHECK: "fooi32"() {bar: dense<tensor<1x1x1xi32>, {{\[\[\[}}5]]]>} : () -> ()
"fooi32"(){bar: dense<tensor<1x1x1xi32>, [[[5]]]>} : () -> ()
// CHECK: "fooi33"() {bar: dense<tensor<2x1x4xi33>, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> ()
"fooi33"(){bar: dense<tensor<2x1x4xi33>, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> ()
// CHECK: "fooi43"() {bar: dense<tensor<2x1x4xi43>, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> ()
"fooi43"(){bar: dense<tensor<2x1x4xi43>, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> ()
// CHECK: "fooi53"() {bar: dense<tensor<2x1x4xi53>, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> ()
"fooi53"(){bar: dense<tensor<2x1x4xi53>, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> ()
// CHECK: "fooi64"() {bar: dense<tensor<2x1x4xi64>, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 3, -1, 2]]]>} : () -> ()
"fooi64"(){bar: dense<tensor<2x1x4xi64>, [[[1, -2, 1, 2]], [[0, 3, -1, 2]]]>} : () -> ()
// CHECK: "fooi64"() {bar: dense<tensor<1x1x1xi64>, {{\[\[\[}}-5]]]>} : () -> ()
"fooi64"(){bar: dense<tensor<1x1x1xi64>, [[[-5]]]>} : () -> ()
// CHECK: "foo2"() {bar: dense<tensor<0xi32>, []>} : () -> ()
"foo2"(){bar: dense<tensor<0 x i32>, []>} : () -> ()
// CHECK: "foo2"() {bar: dense<tensor<1x0xi32>, {{\[\[}}]]>} : () -> ()
"foo2"(){bar: dense<tensor<1x0 x i32>, [[]]>} : () -> ()
// CHECK: "foo3"() {bar: dense<tensor<2x1x4xi32>, {{\[\[\[}}5, -6, 1, 2]], {{\[\[}}7, 8, 3, 4]]]>} : () -> ()
"foo3"(){bar: dense<tensor<2x1x4xi32>, [[[5, -6, 1, 2]], [[7, 8, 3, 4]]]>} : () -> ()
// CHECK: "float1"() {bar: dense<tensor<1x1x1xf32>, {{\[\[\[}}5.000000e+00]]]>} : () -> ()
"float1"(){bar: dense<tensor<1x1x1xf32>, [[[5.0]]]>} : () -> ()
// CHECK: "float2"() {bar: dense<tensor<0xf32>, []>} : () -> ()
"float2"(){bar: dense<tensor<0 x f32>, []>} : () -> ()
// CHECK: "float2"() {bar: dense<tensor<1x0xf32>, {{\[\[}}]]>} : () -> ()
"float2"(){bar: dense<tensor<1x0 x f32>, [[]]>} : () -> ()
// CHECK: "bfloat16"() {bar: dense<tensor<2x1x4xbf16>, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> ()
"bfloat16"(){bar: dense<tensor<2x1x4xbf16>, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> ()
// CHECK: "float16"() {bar: dense<tensor<2x1x4xf16>, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> ()
"float16"(){bar: dense<tensor<2x1x4xf16>, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> ()
// CHECK: "float32"() {bar: dense<tensor<2x1x4xf32>, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> ()
"float32"(){bar: dense<tensor<2x1x4xf32>, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> ()
// CHECK: "float64"() {bar: dense<tensor<2x1x4xf64>, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> ()
"float64"(){bar: dense<tensor<2x1x4xf64>, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> ()
return
}
// CHECK-LABEL: cfgfunc @densevectorattr
cfgfunc @densevectorattr() -> () {
bb0:
// NOTE: The {{\[\[}} syntax is because "[[" confuses FileCheck.
// CHECK: "fooi8"() {bar: dense<vector<1x1x1xi8>, {{\[\[\[}}5]]]>} : () -> ()
"fooi8"(){bar: dense<vector<1x1x1xi8>, [[[5]]]>} : () -> ()
// CHECK: "fooi16"() {bar: dense<vector<1x1x1xi16>, {{\[\[\[}}-5]]]>} : () -> ()
"fooi16"(){bar: dense<vector<1x1x1xi16>, [[[-5]]]>} : () -> ()
// CHECK: "foo32"() {bar: dense<vector<1x1x1xi32>, {{\[\[\[}}5]]]>} : () -> ()
"foo32"(){bar: dense<vector<1x1x1xi32>, [[[5]]]>} : () -> ()
// CHECK: "fooi64"() {bar: dense<vector<1x1x1xi64>, {{\[\[\[}}-5]]]>} : () -> ()
"fooi64"(){bar: dense<vector<1x1x1xi64>, [[[-5]]]>} : () -> ()
// CHECK: "foo2"() {bar: dense<vector<0xi32>, []>} : () -> ()
"foo2"(){bar: dense<vector<0 x i32>, []>} : () -> ()
// CHECK: "foo2"() {bar: dense<vector<1x0xi32>, {{\[\[}}]]>} : () -> ()
"foo2"(){bar: dense<vector<1x0 x i32>, [[]]>} : () -> ()
// CHECK: "foo3"() {bar: dense<vector<2x1x4xi32>, {{\[\[\[}}5, -6, 1, 2]], {{\[\[}}7, 8, 3, 4]]]>} : () -> ()
"foo3"(){bar: dense<vector<2x1x4xi32>, [[[5, -6, 1, 2]], [[7, 8, 3, 4]]]>} : () -> ()
// CHECK: "float1"() {bar: dense<vector<1x1x1xf32>, {{\[\[\[}}5.000000e+00]]]>} : () -> ()
"float1"(){bar: dense<vector<1x1x1xf32>, [[[5.0]]]>} : () -> ()
// CHECK: "float2"() {bar: dense<vector<0xf32>, []>} : () -> ()
"float2"(){bar: dense<vector<0 x f32>, []>} : () -> ()
// CHECK: "float2"() {bar: dense<vector<1x0xf32>, {{\[\[}}]]>} : () -> ()
"float2"(){bar: dense<vector<1x0 x f32>, [[]]>} : () -> ()
// CHECK: "bfloat16"() {bar: dense<vector<2x1x4xbf16>, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> ()
"bfloat16"(){bar: dense<vector<2x1x4xbf16>, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> ()
// CHECK: "float16"() {bar: dense<vector<2x1x4xf16>, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> ()
"float16"(){bar: dense<vector<2x1x4xf16>, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> ()
// CHECK: "float32"() {bar: dense<vector<2x1x4xf32>, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> ()
"float32"(){bar: dense<vector<2x1x4xf32>, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> ()
// CHECK: "float64"() {bar: dense<vector<2x1x4xf64>, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> ()
"float64"(){bar: dense<vector<2x1x4xf64>, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> ()
return
}