Add support to opaque elements attributes

For some of the constant vector / tesor, if the compiler doesn't need to
interpret their elements content, they can be stored in this class to save the
serialize / deserialize cost.

syntax:

`opaque<` tensor-type `,` opaque-string `>`

opaque-string ::= `0x` [0-9a-fA-F]*
PiperOrigin-RevId: 218399426
This commit is contained in:
Feng Liu 2018-10-23 13:44:04 -07:00 committed by jpienaar
parent 301f83f906
commit 3d7ab2d265
9 changed files with 142 additions and 12 deletions

View File

@ -49,6 +49,7 @@ public:
SplatElements,
DenseIntElements,
DenseFPElements,
OpaqueElements,
SparseElements,
FIRST_ELEMENTS_ATTR = SplatElements,
LAST_ELEMENTS_ATTR = SparseElements,
@ -335,11 +336,6 @@ private:
/// object.
class DenseIntElementsAttr : public DenseElementsAttr {
public:
DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef<char> data,
size_t bitsWidth)
: DenseElementsAttr(Kind::DenseIntElements, type, data),
bitsWidth(bitsWidth) {}
// TODO: returns APInts instead of IntegerAttr.
void getValues(SmallVectorImpl<Attribute *> &values) const;
@ -361,6 +357,12 @@ public:
}
private:
friend class DenseElementsAttr;
DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef<char> data,
size_t bitsWidth)
: DenseElementsAttr(Kind::DenseIntElements, type, data),
bitsWidth(bitsWidth) {}
~DenseIntElementsAttr() = delete;
size_t bitsWidth;
@ -370,9 +372,6 @@ private:
/// object. Each element is stored as a double.
class DenseFPElementsAttr : public DenseElementsAttr {
public:
DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef<char> data)
: DenseElementsAttr(Kind::DenseFPElements, type, data) {}
// TODO: returns APFPs instead of FloatAttr.
void getValues(SmallVectorImpl<Attribute *> &values) const;
@ -384,9 +383,33 @@ public:
}
private:
friend class DenseElementsAttr;
DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef<char> data)
: DenseElementsAttr(Kind::DenseFPElements, type, data) {}
~DenseFPElementsAttr() = delete;
};
/// An attribute represents a reference to a tensor constant with opaque
/// content. This respresentation is for tensor constants which the compiler
/// doesn't need to interpret.
class OpaqueElementsAttr : public ElementsAttr {
public:
static OpaqueElementsAttr *get(VectorOrTensorType *type, StringRef bytes);
StringRef getValue() const { return bytes; }
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::OpaqueElements;
}
private:
OpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes)
: ElementsAttr(Kind::OpaqueElements, type), bytes(bytes) {}
~OpaqueElementsAttr() = delete;
StringRef bytes;
};
/// An attribute represents a reference to a sparse vector or tensor object.
///
/// This class uses COO (coordinate list) encoding to represent the sparse

View File

@ -106,8 +106,10 @@ public:
ElementsAttr *getDenseElementsAttr(VectorOrTensorType *type,
ArrayRef<char> data);
ElementsAttr *getSparseElementsAttr(VectorOrTensorType *type,
DenseIntElementsAttr *indicies,
DenseIntElementsAttr *indices,
DenseElementsAttr *values);
ElementsAttr *getOpaqueElementsAttr(VectorOrTensorType *type,
StringRef bytes);
// Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position);

View File

@ -443,6 +443,14 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
}
break;
}
case Attribute::Kind::OpaqueElements: {
auto *eltsAttr = cast<OpaqueElementsAttr>(attr);
os << "opaque<";
printType(eltsAttr->getType());
os << ", " << '"' << "0x" << llvm::toHex(eltsAttr->getValue()) << '"'
<< '>';
break;
}
case Attribute::Kind::DenseIntElements:
case Attribute::Kind::DenseFPElements: {
auto *eltsAttr = cast<DenseElementsAttr>(attr);

View File

@ -159,9 +159,14 @@ ElementsAttr *Builder::getDenseElementsAttr(VectorOrTensorType *type,
}
ElementsAttr *Builder::getSparseElementsAttr(VectorOrTensorType *type,
DenseIntElementsAttr *indicies,
DenseIntElementsAttr *indices,
DenseElementsAttr *values) {
return SparseElementsAttr::get(type, indicies, values);
return SparseElementsAttr::get(type, indices, values);
}
ElementsAttr *Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
StringRef bytes) {
return OpaqueElementsAttr::get(type, bytes);
}
//===----------------------------------------------------------------------===//

View File

@ -211,6 +211,23 @@ struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttr *> {
return lhs == std::make_pair(rhs->getType(), rhs->getRawData());
}
};
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttr *> {
using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
using DenseMapInfo<OpaqueElementsAttr *>::getHashValue;
using DenseMapInfo<OpaqueElementsAttr *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
key.first, hash_combine_range(key.second.begin(), key.second.end()));
}
static bool isEqual(const KeyTy &lhs, const OpaqueElementsAttr *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == std::make_pair(rhs->getType(), rhs->getValue());
}
};
} // end anonymous namespace.
namespace mlir {
@ -316,6 +333,9 @@ public:
using DenseElementsAttrSet =
DenseSet<DenseElementsAttr *, DenseElementsAttrInfo>;
DenseElementsAttrSet denseElementsAttrs;
using OpaqueElementsAttrSet =
DenseSet<OpaqueElementsAttr *, OpaqueElementsAttrInfo>;
OpaqueElementsAttrSet opaqueElementsAttrs;
DenseMap<std::tuple<Type *, DenseElementsAttr *, DenseElementsAttr *>,
SparseElementsAttr *>
sparseElementsAttrs;
@ -888,6 +908,28 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
return *existing.first = result;
}
OpaqueElementsAttr *OpaqueElementsAttr::get(VectorOrTensorType *type,
StringRef bytes) {
assert(isValidTensorElementType(type->getElementType()) &&
"Input element type should be a valid tensor element type");
auto &impl = type->getContext()->getImpl();
// Look to see if this constant is already defined.
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key);
// If we already have it, return that value.
if (!existing.second)
return *existing.first;
// Otherwise, allocate a new one, unique it and return it.
auto *result = impl.allocator.Allocate<OpaqueElementsAttr>();
bytes = bytes.copy(impl.allocator);
new (result) OpaqueElementsAttr(type, bytes);
return *existing.first = result;
}
DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
ArrayRef<char> data) {
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();

View File

@ -41,6 +41,7 @@
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include <algorithm>
using namespace mlir;
using llvm::MemoryBuffer;
@ -895,7 +896,26 @@ Attribute *Parser::parseAttribute() {
auto *function = resolveFunctionReference(nameStr, nameLoc, fnType);
return function ? builder.getFunctionAttr(function) : nullptr;
}
case Token::kw_opaque: {
consumeToken(Token::kw_opaque);
if (parseToken(Token::less, "expected '<' after 'opaque'"))
return nullptr;
auto *type = parseVectorOrTensorType();
if (!type)
return nullptr;
auto val = getToken().getStringValue();
if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
return (emitError("opaque string should start with '0x'"), nullptr);
val = val.substr(2);
if (!std::all_of(val.begin(), val.end(),
[](char c) { return llvm::isHexDigit(c); })) {
return (emitError("opaque string only contains hex digits"), nullptr);
}
consumeToken(Token::string);
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return builder.getOpaqueElementsAttr(type, llvm::fromHex(val));
}
case Token::kw_splat: {
consumeToken(Token::kw_splat);
if (parseToken(Token::less, "expected '<' after 'splat'"))

View File

@ -110,6 +110,7 @@ TOK_KEYWORD(memref)
TOK_KEYWORD(min)
TOK_KEYWORD(mlfunc)
TOK_KEYWORD(mod)
TOK_KEYWORD(opaque)
TOK_KEYWORD(return)
TOK_KEYWORD(size)
TOK_KEYWORD(step)

View File

@ -699,3 +699,17 @@ cfgfunc @elementsattr_toolarge2() -> () {
bb0:
"foo"(){bar: dense<tensor<1xi8>, [-777]>} : () -> () // expected-error {{tensor literal element has more bits than that specified in the type}}
}
// -----
cfgfunc @elementsattr_malformed_opaque() -> () {
bb0:
"foo"(){bar: opaque<tensor<1xi8>, "0xQZz123">} : () -> () // expected-error {{opaque string only contains hex digits}}
}
// -----
cfgfunc @elementsattr_malformed_opaque1() -> () {
bb0:
"foo"(){bar: opaque<tensor<1xi8>, "00abc">} : () -> () // expected-error {{opaque string should start with '0x'}}
}

View File

@ -499,6 +499,21 @@ bb0:
return
}
// CHECK-LABEL: cfgfunc @opaquetensorattr
cfgfunc @opaquetensorattr() -> () {
bb0:
// CHECK: "opaqueIntTensor"() {bar: opaque<tensor<2x1x4xi32>, "0x68656C6C6F">} : () -> ()
"opaqueIntTensor"(){bar: opaque<tensor<2x1x4xi32>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueFloatTensor"() {bar: opaque<tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
"opaqueFloatTensor"(){bar: opaque<tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueStringTensor"() {bar: opaque<tensor<2x1x4xtf_string>, "0x68656C6C6F">} : () -> ()
"opaqueStringTensor"(){bar: opaque<tensor<2x1x4xtf_string>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueResourceTensor"() {bar: opaque<tensor<2x1x4xtf_resource>, "0x68656C6C6F">} : () -> ()
"opaqueResourceTensor"(){bar: opaque<tensor<2x1x4xtf_resource>, "0x68656C6C6F">} : () -> ()
return
}
// CHECK-LABEL: cfgfunc @densetensorattr
cfgfunc @densetensorattr() -> () {
bb0: