forked from OSchip/llvm-project
Add support to opaque elements attributes
For some of the constant vector / tesor, if the compiler doesn't need to interpret their elements content, they can be stored in this class to save the serialize / deserialize cost. syntax: `opaque<` tensor-type `,` opaque-string `>` opaque-string ::= `0x` [0-9a-fA-F]* PiperOrigin-RevId: 218399426
This commit is contained in:
parent
301f83f906
commit
3d7ab2d265
|
@ -49,6 +49,7 @@ public:
|
|||
SplatElements,
|
||||
DenseIntElements,
|
||||
DenseFPElements,
|
||||
OpaqueElements,
|
||||
SparseElements,
|
||||
FIRST_ELEMENTS_ATTR = SplatElements,
|
||||
LAST_ELEMENTS_ATTR = SparseElements,
|
||||
|
@ -335,11 +336,6 @@ private:
|
|||
/// object.
|
||||
class DenseIntElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef<char> data,
|
||||
size_t bitsWidth)
|
||||
: DenseElementsAttr(Kind::DenseIntElements, type, data),
|
||||
bitsWidth(bitsWidth) {}
|
||||
|
||||
// TODO: returns APInts instead of IntegerAttr.
|
||||
void getValues(SmallVectorImpl<Attribute *> &values) const;
|
||||
|
||||
|
@ -361,6 +357,12 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
friend class DenseElementsAttr;
|
||||
DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef<char> data,
|
||||
size_t bitsWidth)
|
||||
: DenseElementsAttr(Kind::DenseIntElements, type, data),
|
||||
bitsWidth(bitsWidth) {}
|
||||
|
||||
~DenseIntElementsAttr() = delete;
|
||||
|
||||
size_t bitsWidth;
|
||||
|
@ -370,9 +372,6 @@ private:
|
|||
/// object. Each element is stored as a double.
|
||||
class DenseFPElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef<char> data)
|
||||
: DenseElementsAttr(Kind::DenseFPElements, type, data) {}
|
||||
|
||||
// TODO: returns APFPs instead of FloatAttr.
|
||||
void getValues(SmallVectorImpl<Attribute *> &values) const;
|
||||
|
||||
|
@ -384,9 +383,33 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
friend class DenseElementsAttr;
|
||||
DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef<char> data)
|
||||
: DenseElementsAttr(Kind::DenseFPElements, type, data) {}
|
||||
~DenseFPElementsAttr() = delete;
|
||||
};
|
||||
|
||||
/// An attribute represents a reference to a tensor constant with opaque
|
||||
/// content. This respresentation is for tensor constants which the compiler
|
||||
/// doesn't need to interpret.
|
||||
class OpaqueElementsAttr : public ElementsAttr {
|
||||
public:
|
||||
static OpaqueElementsAttr *get(VectorOrTensorType *type, StringRef bytes);
|
||||
|
||||
StringRef getValue() const { return bytes; }
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::OpaqueElements;
|
||||
}
|
||||
|
||||
private:
|
||||
OpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes)
|
||||
: ElementsAttr(Kind::OpaqueElements, type), bytes(bytes) {}
|
||||
~OpaqueElementsAttr() = delete;
|
||||
StringRef bytes;
|
||||
};
|
||||
|
||||
/// An attribute represents a reference to a sparse vector or tensor object.
|
||||
///
|
||||
/// This class uses COO (coordinate list) encoding to represent the sparse
|
||||
|
|
|
@ -106,8 +106,10 @@ public:
|
|||
ElementsAttr *getDenseElementsAttr(VectorOrTensorType *type,
|
||||
ArrayRef<char> data);
|
||||
ElementsAttr *getSparseElementsAttr(VectorOrTensorType *type,
|
||||
DenseIntElementsAttr *indicies,
|
||||
DenseIntElementsAttr *indices,
|
||||
DenseElementsAttr *values);
|
||||
ElementsAttr *getOpaqueElementsAttr(VectorOrTensorType *type,
|
||||
StringRef bytes);
|
||||
|
||||
// Affine expressions and affine maps.
|
||||
AffineExpr getAffineDimExpr(unsigned position);
|
||||
|
|
|
@ -443,6 +443,14 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
|
|||
}
|
||||
break;
|
||||
}
|
||||
case Attribute::Kind::OpaqueElements: {
|
||||
auto *eltsAttr = cast<OpaqueElementsAttr>(attr);
|
||||
os << "opaque<";
|
||||
printType(eltsAttr->getType());
|
||||
os << ", " << '"' << "0x" << llvm::toHex(eltsAttr->getValue()) << '"'
|
||||
<< '>';
|
||||
break;
|
||||
}
|
||||
case Attribute::Kind::DenseIntElements:
|
||||
case Attribute::Kind::DenseFPElements: {
|
||||
auto *eltsAttr = cast<DenseElementsAttr>(attr);
|
||||
|
|
|
@ -159,9 +159,14 @@ ElementsAttr *Builder::getDenseElementsAttr(VectorOrTensorType *type,
|
|||
}
|
||||
|
||||
ElementsAttr *Builder::getSparseElementsAttr(VectorOrTensorType *type,
|
||||
DenseIntElementsAttr *indicies,
|
||||
DenseIntElementsAttr *indices,
|
||||
DenseElementsAttr *values) {
|
||||
return SparseElementsAttr::get(type, indicies, values);
|
||||
return SparseElementsAttr::get(type, indices, values);
|
||||
}
|
||||
|
||||
ElementsAttr *Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
|
||||
StringRef bytes) {
|
||||
return OpaqueElementsAttr::get(type, bytes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -211,6 +211,23 @@ struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttr *> {
|
|||
return lhs == std::make_pair(rhs->getType(), rhs->getRawData());
|
||||
}
|
||||
};
|
||||
|
||||
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttr *> {
|
||||
using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
|
||||
using DenseMapInfo<OpaqueElementsAttr *>::getHashValue;
|
||||
using DenseMapInfo<OpaqueElementsAttr *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine(
|
||||
key.first, hash_combine_range(key.second.begin(), key.second.end()));
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const OpaqueElementsAttr *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == std::make_pair(rhs->getType(), rhs->getValue());
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
namespace mlir {
|
||||
|
@ -316,6 +333,9 @@ public:
|
|||
using DenseElementsAttrSet =
|
||||
DenseSet<DenseElementsAttr *, DenseElementsAttrInfo>;
|
||||
DenseElementsAttrSet denseElementsAttrs;
|
||||
using OpaqueElementsAttrSet =
|
||||
DenseSet<OpaqueElementsAttr *, OpaqueElementsAttrInfo>;
|
||||
OpaqueElementsAttrSet opaqueElementsAttrs;
|
||||
DenseMap<std::tuple<Type *, DenseElementsAttr *, DenseElementsAttr *>,
|
||||
SparseElementsAttr *>
|
||||
sparseElementsAttrs;
|
||||
|
@ -888,6 +908,28 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
|
|||
return *existing.first = result;
|
||||
}
|
||||
|
||||
OpaqueElementsAttr *OpaqueElementsAttr::get(VectorOrTensorType *type,
|
||||
StringRef bytes) {
|
||||
assert(isValidTensorElementType(type->getElementType()) &&
|
||||
"Input element type should be a valid tensor element type");
|
||||
|
||||
auto &impl = type->getContext()->getImpl();
|
||||
|
||||
// Look to see if this constant is already defined.
|
||||
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
|
||||
auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key);
|
||||
|
||||
// If we already have it, return that value.
|
||||
if (!existing.second)
|
||||
return *existing.first;
|
||||
|
||||
// Otherwise, allocate a new one, unique it and return it.
|
||||
auto *result = impl.allocator.Allocate<OpaqueElementsAttr>();
|
||||
bytes = bytes.copy(impl.allocator);
|
||||
new (result) OpaqueElementsAttr(type, bytes);
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
|
||||
ArrayRef<char> data) {
|
||||
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
|
||||
|
|
|
@ -41,6 +41,7 @@
|
|||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include <algorithm>
|
||||
|
||||
using namespace mlir;
|
||||
using llvm::MemoryBuffer;
|
||||
|
@ -895,7 +896,26 @@ Attribute *Parser::parseAttribute() {
|
|||
auto *function = resolveFunctionReference(nameStr, nameLoc, fnType);
|
||||
return function ? builder.getFunctionAttr(function) : nullptr;
|
||||
}
|
||||
|
||||
case Token::kw_opaque: {
|
||||
consumeToken(Token::kw_opaque);
|
||||
if (parseToken(Token::less, "expected '<' after 'opaque'"))
|
||||
return nullptr;
|
||||
auto *type = parseVectorOrTensorType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
auto val = getToken().getStringValue();
|
||||
if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
|
||||
return (emitError("opaque string should start with '0x'"), nullptr);
|
||||
val = val.substr(2);
|
||||
if (!std::all_of(val.begin(), val.end(),
|
||||
[](char c) { return llvm::isHexDigit(c); })) {
|
||||
return (emitError("opaque string only contains hex digits"), nullptr);
|
||||
}
|
||||
consumeToken(Token::string);
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
return builder.getOpaqueElementsAttr(type, llvm::fromHex(val));
|
||||
}
|
||||
case Token::kw_splat: {
|
||||
consumeToken(Token::kw_splat);
|
||||
if (parseToken(Token::less, "expected '<' after 'splat'"))
|
||||
|
|
|
@ -110,6 +110,7 @@ TOK_KEYWORD(memref)
|
|||
TOK_KEYWORD(min)
|
||||
TOK_KEYWORD(mlfunc)
|
||||
TOK_KEYWORD(mod)
|
||||
TOK_KEYWORD(opaque)
|
||||
TOK_KEYWORD(return)
|
||||
TOK_KEYWORD(size)
|
||||
TOK_KEYWORD(step)
|
||||
|
|
|
@ -699,3 +699,17 @@ cfgfunc @elementsattr_toolarge2() -> () {
|
|||
bb0:
|
||||
"foo"(){bar: dense<tensor<1xi8>, [-777]>} : () -> () // expected-error {{tensor literal element has more bits than that specified in the type}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @elementsattr_malformed_opaque() -> () {
|
||||
bb0:
|
||||
"foo"(){bar: opaque<tensor<1xi8>, "0xQZz123">} : () -> () // expected-error {{opaque string only contains hex digits}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @elementsattr_malformed_opaque1() -> () {
|
||||
bb0:
|
||||
"foo"(){bar: opaque<tensor<1xi8>, "00abc">} : () -> () // expected-error {{opaque string should start with '0x'}}
|
||||
}
|
|
@ -499,6 +499,21 @@ bb0:
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: cfgfunc @opaquetensorattr
|
||||
cfgfunc @opaquetensorattr() -> () {
|
||||
bb0:
|
||||
// CHECK: "opaqueIntTensor"() {bar: opaque<tensor<2x1x4xi32>, "0x68656C6C6F">} : () -> ()
|
||||
"opaqueIntTensor"(){bar: opaque<tensor<2x1x4xi32>, "0x68656C6C6F">} : () -> ()
|
||||
// CHECK: "opaqueFloatTensor"() {bar: opaque<tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
|
||||
"opaqueFloatTensor"(){bar: opaque<tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
|
||||
|
||||
// CHECK: "opaqueStringTensor"() {bar: opaque<tensor<2x1x4xtf_string>, "0x68656C6C6F">} : () -> ()
|
||||
"opaqueStringTensor"(){bar: opaque<tensor<2x1x4xtf_string>, "0x68656C6C6F">} : () -> ()
|
||||
// CHECK: "opaqueResourceTensor"() {bar: opaque<tensor<2x1x4xtf_resource>, "0x68656C6C6F">} : () -> ()
|
||||
"opaqueResourceTensor"(){bar: opaque<tensor<2x1x4xtf_resource>, "0x68656C6C6F">} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: cfgfunc @densetensorattr
|
||||
cfgfunc @densetensorattr() -> () {
|
||||
bb0:
|
||||
|
|
Loading…
Reference in New Issue