Add support to constant sparse tensor / vector attribute

The SparseElementsAttr uses (COO) Coordinate List encoding to represents a
sparse tensor / vector. Specifically, the coordinates and values are stored as
two dense elements attributes. The first dense elements attribute is a 2-D
attribute with shape [N, ndims], which contains the indices of the elements
with nonzero values in the constant vector/tensor. The second elements
attribute is a 1-D attribute list with shape [N], which supplies the values for
each element in the first elements attribute. ndims is the rank of the
vector/tensor and N is the total nonzero elements.

The syntax is:

`sparse<` (tensor-type | vector-type)`, ` indices-attribute-list, values-attribute-list `>`

Example: a sparse tensor

sparse<vector<3x4xi32>, [[0, 0], [1, 2]], [1, 2]> represents the dense tensor

[[1, 0, 0, 0]
 [0, 0, 2, 0]
 [0, 0, 0, 0]]

PiperOrigin-RevId: 217764319
This commit is contained in:
Feng Liu 2018-10-18 14:02:20 -07:00 committed by jpienaar
parent b5b90e5465
commit 03b48999b6
8 changed files with 260 additions and 12 deletions

View File

@ -48,8 +48,9 @@ public:
SplatElements,
DenseIntElements,
DenseFPElements,
SparseElements,
FIRST_ELEMENTS_ATTR = SplatElements,
LAST_ELEMENTS_ATTR = DenseFPElements,
LAST_ELEMENTS_ATTR = SparseElements,
};
/// Return the classification for this attribute.
@ -271,7 +272,7 @@ private:
/// meaning all of the elements have the same value.
class SplatElementsAttr : public ElementsAttr {
public:
static ElementsAttr *get(VectorOrTensorType *type, Attribute *elt);
static SplatElementsAttr *get(VectorOrTensorType *type, Attribute *elt);
Attribute *getValue() const { return elt; }
/// Method for support type inquiry through isa, cast and dyn_cast.
@ -333,7 +334,7 @@ public:
// TODO: returns APInts instead of IntegerAttr.
void getValues(SmallVectorImpl<Attribute *> &values) const;
APInt getValue(ArrayRef<int> indices) const;
APInt getValue(ArrayRef<unsigned> indices) const;
/// Writes the lowest `bitWidth` bits of `value` to the bit position `bitPos`
/// in array `rawData`.
@ -366,7 +367,7 @@ public:
// TODO: returns APFPs instead of FloatAttr.
void getValues(SmallVectorImpl<Attribute *> &values) const;
APFloat getValue(ArrayRef<int> indices) const;
APFloat getValue(ArrayRef<unsigned> indices) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
@ -376,6 +377,50 @@ public:
private:
~DenseFPElementsAttr() = delete;
};
/// An attribute represents a reference to a sparse vector or tensor object.
///
/// This class uses COO (coordinate list) encoding to represent the sparse
/// elements in an element attribute. Specifically, the sparse vector/tensor
/// stores the indices and values as two separate dense elements attributes. The
/// dense elements attribute indices is a 2-D tensor with shape [N, ndims],
/// which specifies the indices of the elements in the sparse tensor that
/// contains nonzero values. The dense elements attribute values is a 1-D tensor
/// with shape [N], and it supplies the corresponding values for the indices.
///
/// For example,
/// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
/// [[1, 0, 0, 0],
/// [0, 0, 5, 0],
/// [0, 0, 0, 0]].
class SparseElementsAttr : public ElementsAttr {
public:
static SparseElementsAttr *get(VectorOrTensorType *type,
DenseIntElementsAttr *indices,
DenseElementsAttr *values);
DenseIntElementsAttr *getIndices() const { return indices; }
DenseElementsAttr *getValues() const { return values; }
/// Return the value at the given index.
Attribute *getValue(ArrayRef<unsigned> index) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::SparseElements;
}
private:
SparseElementsAttr(VectorOrTensorType *type, DenseIntElementsAttr *indices,
DenseElementsAttr *values)
: ElementsAttr(Kind::SparseElements, type), indices(indices),
values(values) {}
~SparseElementsAttr() = delete;
DenseIntElementsAttr *const indices;
DenseElementsAttr *const values;
};
} // end namespace mlir.
#endif

View File

@ -44,6 +44,8 @@ class TypeAttr;
class ArrayAttr;
class FunctionAttr;
class ElementsAttr;
class DenseElementsAttr;
class DenseIntElementsAttr;
class AffineMapAttr;
class AffineMap;
@ -102,6 +104,9 @@ public:
ElementsAttr *getSplatElementsAttr(VectorOrTensorType *type, Attribute *elt);
ElementsAttr *getDenseElementsAttr(VectorOrTensorType *type,
ArrayRef<char> data);
ElementsAttr *getSparseElementsAttr(VectorOrTensorType *type,
DenseIntElementsAttr *indicies,
DenseElementsAttr *values);
// Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position);
@ -264,8 +269,7 @@ public:
}
private:
template <typename T>
T *insertTerminator(T *term) {
template <typename T> T *insertTerminator(T *term) {
block->setTerminator(term);
return term;
}

View File

@ -477,6 +477,17 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
os << '>';
break;
}
case Attribute::Kind::SparseElements: {
auto *elementsAttr = cast<SparseElementsAttr>(attr);
os << "sparse<";
printType(elementsAttr->getType());
os << ", ";
printDenseElementsAttr(elementsAttr->getIndices());
os << ", ";
printDenseElementsAttr(elementsAttr->getValues());
os << '>';
break;
}
}
}

View File

@ -154,6 +154,12 @@ ElementsAttr *Builder::getDenseElementsAttr(VectorOrTensorType *type,
return DenseElementsAttr::get(type, data);
}
ElementsAttr *Builder::getSparseElementsAttr(VectorOrTensorType *type,
DenseIntElementsAttr *indicies,
DenseElementsAttr *values) {
return SparseElementsAttr::get(type, indicies, values);
}
//===----------------------------------------------------------------------===//
// Affine Expressions, Affine Maps, and Integet Sets.
//===----------------------------------------------------------------------===//

View File

@ -297,6 +297,9 @@ public:
using DenseElementsAttrSet =
DenseSet<DenseElementsAttr *, DenseElementsAttrInfo>;
DenseElementsAttrSet denseElementsAttrs;
DenseMap<std::tuple<Type *, DenseElementsAttr *, DenseElementsAttr *>,
SparseElementsAttr *>
sparseElementsAttrs;
public:
MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {}
@ -951,7 +954,8 @@ void DenseFPElementsAttr::getValues(
}
}
ElementsAttr *SplatElementsAttr::get(VectorOrTensorType *type, Attribute *elt) {
SplatElementsAttr *SplatElementsAttr::get(VectorOrTensorType *type,
Attribute *elt) {
auto &impl = type->getContext()->getImpl();
// Look to see if we already have this.
@ -968,6 +972,26 @@ ElementsAttr *SplatElementsAttr::get(VectorOrTensorType *type, Attribute *elt) {
return result;
}
SparseElementsAttr *SparseElementsAttr::get(VectorOrTensorType *type,
DenseIntElementsAttr *indices,
DenseElementsAttr *values) {
auto &impl = type->getContext()->getImpl();
// Look to see if we already have this.
auto key = std::make_tuple(type, indices, values);
auto *&result = impl.sparseElementsAttrs[key];
// If we already have it, return that value.
if (result)
return result;
// Otherwise, allocate them into the bump pointer.
result = impl.allocator.Allocate<SparseElementsAttr>();
new (result) SparseElementsAttr(type, indices, values);
return result;
}
//===----------------------------------------------------------------------===//
// AffineMap and AffineExpr uniquing
//===----------------------------------------------------------------------===//

View File

@ -200,6 +200,7 @@ public:
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType *type);
Attribute *parseAttribute();
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
// Polyhedral structures.
@ -207,7 +208,8 @@ public:
AffineMap parseAffineMapReference();
IntegerSet parseIntegerSetInline();
IntegerSet parseIntegerSetReference();
ElementsAttr *parseDenseElementsAttr(VectorOrTensorType *type);
DenseElementsAttr *parseDenseElementsAttr(VectorOrTensorType *type);
DenseElementsAttr *parseDenseElementsAttr(Type *eltType, bool isVector);
VectorOrTensorType *parseVectorOrTensorType();
private:
@ -803,6 +805,8 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
/// | function-id `:` function-type
/// | (`splat<` | `dense<`) (tensor-type | vector-type)`,`
/// attribute-value `>`
/// | `sparse<` (tensor-type | vector-type)`,`
/// attribute-value`, ` attribute-value `>`
///
Attribute *Parser::parseAttribute() {
switch (getToken().getKind()) {
@ -905,7 +909,6 @@ Attribute *Parser::parseAttribute() {
auto *type = parseVectorOrTensorType();
if (!type)
return nullptr;
switch (getToken().getKind()) {
case Token::floatliteral:
case Token::integer:
@ -942,6 +945,64 @@ Attribute *Parser::parseAttribute() {
return (emitError("expected '[' to start dense tensor literal"), nullptr);
}
}
case Token::kw_sparse: {
consumeToken(Token::kw_sparse);
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
return nullptr;
auto *type = parseVectorOrTensorType();
if (!type)
return nullptr;
switch (getToken().getKind()) {
case Token::l_square: {
/// Parse indices
auto *indicesEltType = builder.getIntegerType(32);
auto *indices =
parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
if (parseToken(Token::comma, "expected ','"))
return nullptr;
/// Parse values.
auto *valuesEltType = type->getElementType();
auto *values =
parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
/// Sanity check.
auto *indicesType = indices->getType();
auto *valuesType = values->getType();
auto sameShape = (indicesType->getRank() == 1) ||
(type->getRank() == indicesType->getDimSize(1));
auto sameElementNum =
indicesType->getDimSize(0) == valuesType->getDimSize(0);
if (!sameShape || !sameElementNum) {
std::string str;
llvm::raw_string_ostream s(str);
s << "expected shape ([";
interleaveComma(type->getShape(), s);
s << "]); inferred shape of indices literal ([";
interleaveComma(indicesType->getShape(), s);
s << "]); inferred shape of values literal ([";
interleaveComma(valuesType->getShape(), s);
s << "])";
return (emitError(s.str()), nullptr);
}
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
// Build the sparse elements attribute by the indices and values.
return builder.getSparseElementsAttr(
type, cast<DenseIntElementsAttr>(indices), values);
}
default:
return (emitError("expected '[' to start sparse tensor literal"),
nullptr);
}
return (emitError("expected elements literal has a tensor or vector type"),
nullptr);
}
default: {
if (Type *type = parseType())
return builder.getTypeAttr(type);
@ -950,7 +1011,42 @@ Attribute *Parser::parseAttribute() {
}
}
ElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
/// Dense elements attribute.
///
/// dense-attr-list ::= `[` attribute-value `]`
/// attribute-value ::= integer-literal
/// | float-literal
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
///
/// This method returns a constructed dense elements attribute with the shape
/// from the parsing result.
DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
bool isVector) {
TensorLiteralParser literalParser(*this, eltType);
if (literalParser.parse())
return nullptr;
VectorOrTensorType *type;
if (isVector) {
type = builder.getVectorType(literalParser.getShape(), eltType);
} else {
type = builder.getTensorType(literalParser.getShape(), eltType);
}
return (DenseElementsAttr *)builder.getDenseElementsAttr(
type, literalParser.getValues());
}
/// Dense elements attribute.
///
/// dense-attr-list ::= `[` attribute-value `]`
/// attribute-value ::= integer-literal
/// | float-literal
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
///
/// This method compares the shapes from the parsing result and that from the
/// input argument. It returns a constructed dense elements attribute if both
/// match.
DenseElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
auto *eltTy = type->getElementType();
TensorLiteralParser literalParser(*this, eltTy);
if (literalParser.parse())
@ -965,9 +1061,15 @@ ElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
s << "])";
return (emitError(s.str()), nullptr);
}
return builder.getDenseElementsAttr(type, literalParser.getValues());
return (DenseElementsAttr *)builder.getDenseElementsAttr(
type, literalParser.getValues());
}
/// Vector or tensor type for elements attribute.
///
/// vector-or-tensor-type ::= vector-type | tensor-type
///
/// This method also checks the type has static shape and ranked.
VectorOrTensorType *Parser::parseVectorOrTensorType() {
auto *type = dyn_cast<VectorOrTensorType>(parseType());
if (!type) {
@ -982,7 +1084,6 @@ VectorOrTensorType *Parser::parseVectorOrTensorType() {
return (emitError("tensor literals must be ranked and have static shape"),
nullptr);
}
return type;
}

View File

@ -123,6 +123,7 @@ TOK_KEYWORD(tf_string)
TOK_KEYWORD(tf_f32ref)
TOK_KEYWORD(to)
TOK_KEYWORD(true)
TOK_KEYWORD(sparse)
TOK_KEYWORD(vector)
#undef TOK_MARKER

View File

@ -591,3 +591,59 @@ bb0:
"float64"(){bar: dense<vector<2x1x4xf64>, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> ()
return
}
// CHECK-LABEL: cfgfunc @sparsetensorattr
cfgfunc @sparsetensorattr() -> () {
bb0:
// NOTE: The {{\[\[}} syntax is because "[[" confuses FileCheck.
// CHECK: "fooi8"() {bar: sparse<tensor<1x1x1xi8>, {{\[\[}}0, 0, 0]], {{\[}}-2]>} : () -> ()
"fooi8"(){bar: sparse<tensor<1x1x1xi8>, [[0, 0, 0]], [-2]>} : () -> ()
// CHECK: "fooi16"() {bar: sparse<tensor<2x2x2xi16>, {{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]>} : () -> ()
"fooi16"(){bar: sparse<tensor<2x2x2xi16>, [[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]>} : () -> ()
// CHECK: "fooi32"() {bar: sparse<tensor<1x1xi32>, {{\[}}], {{\[}}]>} : () -> ()
"fooi32"(){bar: sparse<tensor<1x1xi32>, [], []>} : () -> ()
// CHECK: "fooi64"() {bar: sparse<tensor<1xi64>, {{\[\[}}0]], {{\[}}-1]>} : () -> ()
"fooi64"(){bar: sparse<tensor<1xi64>, [[0]], [-1]>} : () -> ()
// CHECK: "foo2"() {bar: sparse<tensor<0xi32>, {{\[}}], {{\[}}]>} : () -> ()
"foo2"(){bar: sparse<tensor<0 x i32>, [], []>} : () -> ()
// CHECK: "foof16"() {bar: sparse<tensor<1x1x1xf16>, {{\[\[}}0, 0, 0]], {{\[}}-2.000000e+00]>} : () -> ()
"foof16"(){bar: sparse<tensor<1x1x1xf16>, [[0, 0, 0]], [-2.0]>} : () -> ()
// CHECK: "foobf16"() {bar: sparse<tensor<2x2x2xbf16>, {{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2.000000e+00, -1.000000e+00, 5.000000e+00]>} : () -> ()
"foobf16"(){bar: sparse<tensor<2x2x2xbf16>, [[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2.0, -1.0, 5.0]>} : () -> ()
// CHECK: "foof32"() {bar: sparse<tensor<1x1xf32>, {{\[}}], {{\[}}]>} : () -> ()
"foof32"(){bar: sparse<tensor<1x0x1xf32>, [], []>} : () -> ()
// CHECK: "foof64"() {bar: sparse<tensor<1xf64>, {{\[\[}}0]], {{\[}}-1.000000e+00]>} : () -> ()
"foof64"(){bar: sparse<tensor<1xf64>, [[0]], [-1.0]>} : () -> ()
// CHECK: "foof320"() {bar: sparse<tensor<0xf32>, {{\[}}], {{\[}}]>} : () -> ()
"foof320"(){bar: sparse<tensor<0 x f32>, [], []>} : () -> ()
return
}
// CHECK-LABEL: cfgfunc @sparsevectorattr
cfgfunc @sparsevectorattr() -> () {
bb0:
// NOTE: The {{\[\[}} syntax is because "[[" confuses FileCheck.
// CHECK: "fooi8"() {bar: sparse<vector<1x1x1xi8>, {{\[\[}}0, 0, 0]], {{\[}}-2]>} : () -> ()
"fooi8"(){bar: sparse<vector<1x1x1xi8>, [[0, 0, 0]], [-2]>} : () -> ()
// CHECK: "fooi16"() {bar: sparse<vector<2x2x2xi16>, {{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]>} : () -> ()
"fooi16"(){bar: sparse<vector<2x2x2xi16>, [[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]>} : () -> ()
// CHECK: "fooi32"() {bar: sparse<vector<1x1xi32>, {{\[}}], {{\[}}]>} : () -> ()
"fooi32"(){bar: sparse<vector<1x1xi32>, [], []>} : () -> ()
// CHECK: "fooi64"() {bar: sparse<vector<1xi64>, {{\[\[}}0]], {{\[}}-1]>} : () -> ()
"fooi64"(){bar: sparse<vector<1xi64>, [[0]], [-1]>} : () -> ()
// CHECK: "foo2"() {bar: sparse<vector<0xi32>, {{\[}}], {{\[}}]>} : () -> ()
"foo2"(){bar: sparse<vector<0 x i32>, [], []>} : () -> ()
// CHECK: "foof16"() {bar: sparse<vector<1x1x1xf16>, {{\[\[}}0, 0, 0]], {{\[}}-2.000000e+00]>} : () -> ()
"foof16"(){bar: sparse<vector<1x1x1xf16>, [[0, 0, 0]], [-2.0]>} : () -> ()
// CHECK: "foobf16"() {bar: sparse<vector<2x2x2xbf16>, {{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2.000000e+00, -1.000000e+00, 5.000000e+00]>} : () -> ()
"foobf16"(){bar: sparse<vector<2x2x2xbf16>, [[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2.0, -1.0, 5.0]>} : () -> ()
// CHECK: "foof32"() {bar: sparse<vector<1x1xf32>, {{\[}}], {{\[}}]>} : () -> ()
"foof32"(){bar: sparse<vector<1x0x1xf32>, [], []>} : () -> ()
// CHECK: "foof64"() {bar: sparse<vector<1xf64>, {{\[\[}}0]], {{\[}}-1.000000e+00]>} : () -> ()
"foof64"(){bar: sparse<vector<1xf64>, [[0]], [-1.0]>} : () -> ()
// CHECK: "foof320"() {bar: sparse<vector<0xf32>, {{\[}}], {{\[}}]>} : () -> ()
"foof320"(){bar: sparse<vector<0 x f32>, [], []>} : () -> ()
return
}