From 3d7ab2d2652ab562a79e725eebe63b786b2574e9 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Tue, 23 Oct 2018 13:44:04 -0700 Subject: [PATCH] Add support to opaque elements attributes For some of the constant vector / tesor, if the compiler doesn't need to interpret their elements content, they can be stored in this class to save the serialize / deserialize cost. syntax: `opaque<` tensor-type `,` opaque-string `>` opaque-string ::= `0x` [0-9a-fA-F]* PiperOrigin-RevId: 218399426 --- mlir/include/mlir/IR/Attributes.h | 39 ++++++++++++++++++++++------ mlir/include/mlir/IR/Builders.h | 4 ++- mlir/lib/IR/AsmPrinter.cpp | 8 ++++++ mlir/lib/IR/Builders.cpp | 9 +++++-- mlir/lib/IR/MLIRContext.cpp | 42 +++++++++++++++++++++++++++++++ mlir/lib/Parser/Parser.cpp | 22 +++++++++++++++- mlir/lib/Parser/TokenKinds.def | 1 + mlir/test/IR/invalid.mlir | 14 +++++++++++ mlir/test/IR/parser.mlir | 15 +++++++++++ 9 files changed, 142 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 28a3939fb1e1..5b7f930e9acf 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -49,6 +49,7 @@ public: SplatElements, DenseIntElements, DenseFPElements, + OpaqueElements, SparseElements, FIRST_ELEMENTS_ATTR = SplatElements, LAST_ELEMENTS_ATTR = SparseElements, @@ -335,11 +336,6 @@ private: /// object. class DenseIntElementsAttr : public DenseElementsAttr { public: - DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef data, - size_t bitsWidth) - : DenseElementsAttr(Kind::DenseIntElements, type, data), - bitsWidth(bitsWidth) {} - // TODO: returns APInts instead of IntegerAttr. void getValues(SmallVectorImpl &values) const; @@ -361,6 +357,12 @@ public: } private: + friend class DenseElementsAttr; + DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef data, + size_t bitsWidth) + : DenseElementsAttr(Kind::DenseIntElements, type, data), + bitsWidth(bitsWidth) {} + ~DenseIntElementsAttr() = delete; size_t bitsWidth; @@ -370,9 +372,6 @@ private: /// object. Each element is stored as a double. class DenseFPElementsAttr : public DenseElementsAttr { public: - DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef data) - : DenseElementsAttr(Kind::DenseFPElements, type, data) {} - // TODO: returns APFPs instead of FloatAttr. void getValues(SmallVectorImpl &values) const; @@ -384,9 +383,33 @@ public: } private: + friend class DenseElementsAttr; + DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef data) + : DenseElementsAttr(Kind::DenseFPElements, type, data) {} ~DenseFPElementsAttr() = delete; }; +/// An attribute represents a reference to a tensor constant with opaque +/// content. This respresentation is for tensor constants which the compiler +/// doesn't need to interpret. +class OpaqueElementsAttr : public ElementsAttr { +public: + static OpaqueElementsAttr *get(VectorOrTensorType *type, StringRef bytes); + + StringRef getValue() const { return bytes; } + + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool classof(const Attribute *attr) { + return attr->getKind() == Kind::OpaqueElements; + } + +private: + OpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes) + : ElementsAttr(Kind::OpaqueElements, type), bytes(bytes) {} + ~OpaqueElementsAttr() = delete; + StringRef bytes; +}; + /// An attribute represents a reference to a sparse vector or tensor object. /// /// This class uses COO (coordinate list) encoding to represent the sparse diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 9c71296d35f6..edfcfcb0f067 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -106,8 +106,10 @@ public: ElementsAttr *getDenseElementsAttr(VectorOrTensorType *type, ArrayRef data); ElementsAttr *getSparseElementsAttr(VectorOrTensorType *type, - DenseIntElementsAttr *indicies, + DenseIntElementsAttr *indices, DenseElementsAttr *values); + ElementsAttr *getOpaqueElementsAttr(VectorOrTensorType *type, + StringRef bytes); // Affine expressions and affine maps. AffineExpr getAffineDimExpr(unsigned position); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index a01d82695427..dad9f952ce02 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -443,6 +443,14 @@ void ModulePrinter::printAttribute(const Attribute *attr) { } break; } + case Attribute::Kind::OpaqueElements: { + auto *eltsAttr = cast(attr); + os << "opaque<"; + printType(eltsAttr->getType()); + os << ", " << '"' << "0x" << llvm::toHex(eltsAttr->getValue()) << '"' + << '>'; + break; + } case Attribute::Kind::DenseIntElements: case Attribute::Kind::DenseFPElements: { auto *eltsAttr = cast(attr); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 66192f0a867a..5759d8477cad 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -159,9 +159,14 @@ ElementsAttr *Builder::getDenseElementsAttr(VectorOrTensorType *type, } ElementsAttr *Builder::getSparseElementsAttr(VectorOrTensorType *type, - DenseIntElementsAttr *indicies, + DenseIntElementsAttr *indices, DenseElementsAttr *values) { - return SparseElementsAttr::get(type, indicies, values); + return SparseElementsAttr::get(type, indices, values); +} + +ElementsAttr *Builder::getOpaqueElementsAttr(VectorOrTensorType *type, + StringRef bytes) { + return OpaqueElementsAttr::get(type, bytes); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index ee19464b7772..6705619696eb 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -211,6 +211,23 @@ struct DenseElementsAttrInfo : DenseMapInfo { return lhs == std::make_pair(rhs->getType(), rhs->getRawData()); } }; + +struct OpaqueElementsAttrInfo : DenseMapInfo { + using KeyTy = std::pair; + using DenseMapInfo::getHashValue; + using DenseMapInfo::isEqual; + + static unsigned getHashValue(KeyTy key) { + return hash_combine( + key.first, hash_combine_range(key.second.begin(), key.second.end())); + } + + static bool isEqual(const KeyTy &lhs, const OpaqueElementsAttr *rhs) { + if (rhs == getEmptyKey() || rhs == getTombstoneKey()) + return false; + return lhs == std::make_pair(rhs->getType(), rhs->getValue()); + } +}; } // end anonymous namespace. namespace mlir { @@ -316,6 +333,9 @@ public: using DenseElementsAttrSet = DenseSet; DenseElementsAttrSet denseElementsAttrs; + using OpaqueElementsAttrSet = + DenseSet; + OpaqueElementsAttrSet opaqueElementsAttrs; DenseMap, SparseElementsAttr *> sparseElementsAttrs; @@ -888,6 +908,28 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef attrs, return *existing.first = result; } +OpaqueElementsAttr *OpaqueElementsAttr::get(VectorOrTensorType *type, + StringRef bytes) { + assert(isValidTensorElementType(type->getElementType()) && + "Input element type should be a valid tensor element type"); + + auto &impl = type->getContext()->getImpl(); + + // Look to see if this constant is already defined. + OpaqueElementsAttrInfo::KeyTy key({type, bytes}); + auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key); + + // If we already have it, return that value. + if (!existing.second) + return *existing.first; + + // Otherwise, allocate a new one, unique it and return it. + auto *result = impl.allocator.Allocate(); + bytes = bytes.copy(impl.allocator); + new (result) OpaqueElementsAttr(type, bytes); + return *existing.first = result; +} + DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type, ArrayRef data) { auto bitsRequired = (long)type->getBitWidth() * type->getNumElements(); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 2dd949e25c12..171b11938b50 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -41,6 +41,7 @@ #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" +#include using namespace mlir; using llvm::MemoryBuffer; @@ -895,7 +896,26 @@ Attribute *Parser::parseAttribute() { auto *function = resolveFunctionReference(nameStr, nameLoc, fnType); return function ? builder.getFunctionAttr(function) : nullptr; } - + case Token::kw_opaque: { + consumeToken(Token::kw_opaque); + if (parseToken(Token::less, "expected '<' after 'opaque'")) + return nullptr; + auto *type = parseVectorOrTensorType(); + if (!type) + return nullptr; + auto val = getToken().getStringValue(); + if (val.size() < 2 || val[0] != '0' || val[1] != 'x') + return (emitError("opaque string should start with '0x'"), nullptr); + val = val.substr(2); + if (!std::all_of(val.begin(), val.end(), + [](char c) { return llvm::isHexDigit(c); })) { + return (emitError("opaque string only contains hex digits"), nullptr); + } + consumeToken(Token::string); + if (parseToken(Token::greater, "expected '>'")) + return nullptr; + return builder.getOpaqueElementsAttr(type, llvm::fromHex(val)); + } case Token::kw_splat: { consumeToken(Token::kw_splat); if (parseToken(Token::less, "expected '<' after 'splat'")) diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 0d334bd7353f..3431430b83fd 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -110,6 +110,7 @@ TOK_KEYWORD(memref) TOK_KEYWORD(min) TOK_KEYWORD(mlfunc) TOK_KEYWORD(mod) +TOK_KEYWORD(opaque) TOK_KEYWORD(return) TOK_KEYWORD(size) TOK_KEYWORD(step) diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index f1621a0bc4c3..3d7b01748fdc 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -699,3 +699,17 @@ cfgfunc @elementsattr_toolarge2() -> () { bb0: "foo"(){bar: dense, [-777]>} : () -> () // expected-error {{tensor literal element has more bits than that specified in the type}} } + +// ----- + +cfgfunc @elementsattr_malformed_opaque() -> () { +bb0: + "foo"(){bar: opaque, "0xQZz123">} : () -> () // expected-error {{opaque string only contains hex digits}} +} + +// ----- + +cfgfunc @elementsattr_malformed_opaque1() -> () { +bb0: + "foo"(){bar: opaque, "00abc">} : () -> () // expected-error {{opaque string should start with '0x'}} +} \ No newline at end of file diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 9d0f811eb1fc..897c35451301 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -499,6 +499,21 @@ bb0: return } +// CHECK-LABEL: cfgfunc @opaquetensorattr +cfgfunc @opaquetensorattr() -> () { +bb0: +// CHECK: "opaqueIntTensor"() {bar: opaque, "0x68656C6C6F">} : () -> () + "opaqueIntTensor"(){bar: opaque, "0x68656C6C6F">} : () -> () +// CHECK: "opaqueFloatTensor"() {bar: opaque, "0x68656C6C6F">} : () -> () + "opaqueFloatTensor"(){bar: opaque, "0x68656C6C6F">} : () -> () + +// CHECK: "opaqueStringTensor"() {bar: opaque, "0x68656C6C6F">} : () -> () + "opaqueStringTensor"(){bar: opaque, "0x68656C6C6F">} : () -> () +// CHECK: "opaqueResourceTensor"() {bar: opaque, "0x68656C6C6F">} : () -> () + "opaqueResourceTensor"(){bar: opaque, "0x68656C6C6F">} : () -> () + return +} + // CHECK-LABEL: cfgfunc @densetensorattr cfgfunc @densetensorattr() -> () { bb0: