From 2e6cd60d3b7a77c3067ba67e7c2d47f5855e1852 Mon Sep 17 00:00:00 2001 From: Tatiana Shpeisman Date: Mon, 11 Feb 2019 22:51:34 -0800 Subject: [PATCH] Add dialect-specific decoding for opaque constants. Associates opaque constants with a particular dialect. Adds general mechanism to register dialect-specific hooks defined in external components. Adds hooks to decode opaque tensor constant and extract an element of an opaque tensor constant. This CL does not change the existing mechanism for registering constant folding hook yet. One thing at a time. PiperOrigin-RevId: 233544757 --- mlir/include/mlir/IR/Attributes.h | 16 +++++- mlir/include/mlir/IR/Builders.h | 3 +- mlir/include/mlir/IR/Dialect.h | 21 ++++++++ mlir/include/mlir/IR/DialectHooks.h | 78 +++++++++++++++++++++++++++++ mlir/lib/IR/AsmPrinter.cpp | 1 + mlir/lib/IR/AttributeDetail.h | 1 + mlir/lib/IR/Attributes.cpp | 11 ++++ mlir/lib/IR/Builders.cpp | 5 +- mlir/lib/IR/Dialect.cpp | 18 +++++++ mlir/lib/IR/MLIRContext.cpp | 22 +++++--- mlir/lib/Parser/Parser.cpp | 19 ++++++- mlir/test/IR/invalid.mlir | 20 ++++++-- mlir/test/IR/parser.mlir | 16 +++--- 13 files changed, 207 insertions(+), 24 deletions(-) create mode 100644 mlir/include/mlir/IR/DialectHooks.h diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 6e32dff31161..bb66a5dc8935 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -24,6 +24,7 @@ namespace mlir { class AffineMap; +class Dialect; class Function; class FunctionAttr; class FunctionType; @@ -425,17 +426,28 @@ public: /// An attribute represents a reference to a tensor constant with opaque /// content. This respresentation is for tensor constants which the compiler -/// doesn't need to interpret. +/// may not need to interpret. This attribute is always associated with +/// a particular dialect, which provides a method to convert tensor +/// representation to a non-opaque format. class OpaqueElementsAttr : public ElementsAttr { public: using ElementsAttr::ElementsAttr; using ImplType = detail::OpaqueElementsAttributeStorage; using ValueType = StringRef; - static OpaqueElementsAttr get(VectorOrTensorType type, StringRef bytes); + static OpaqueElementsAttr get(Dialect *dialect, VectorOrTensorType type, + StringRef bytes); StringRef getValue() const; + /// Decodes the attribute value using dialect-specific decoding hook. + /// Returns false if decoding is successful. If not, returns true and leaves + /// 'result' argument unspecified. + bool decode(ElementsAttr &result); + + /// Returns dialect associated with this opaque constant. + Dialect *getDialect() const; + /// Method for support type inquiry through isa, cast and dyn_cast. static bool kindof(Kind kind) { return kind == Kind::OpaqueElements; } }; diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 1d9421b909fd..71a17c5a13b0 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -113,7 +113,8 @@ public: ElementsAttr getSparseElementsAttr(VectorOrTensorType type, DenseIntElementsAttr indices, DenseElementsAttr values); - ElementsAttr getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes); + ElementsAttr getOpaqueElementsAttr(Dialect *dialect, VectorOrTensorType type, + StringRef bytes); // Returns a 0-valued attribute of the given `type`. This function only // supports boolean, integer, and 32-/64-bit float types, and vector or ranked // tensor of them. Returns null attribute otherwise. diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index ba8c586725a8..9f7732e37766 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -29,8 +29,12 @@ class AffineMap; class IntegerSet; class Type; +using DialectConstantDecodeHook = + std::function; using DialectConstantFoldHook = std::function, SmallVectorImpl &)>; +using DialectExtractElementHook = + std::function)>; using DialectTypeParserHook = std::function; using DialectTypePrinterHook = std::function; @@ -59,6 +63,23 @@ public: [](const Instruction *op, ArrayRef operands, SmallVectorImpl &results) { return true; }; + /// Registered hook to decode opaque constants associated with this + /// dialect. The hook function attempts to decode an opaque constant tensor + /// into a tensor with non-opaque content. If decoding is successful, this + /// method returns false and sets 'output' attribute. If not, it returns true + /// and leaves 'output' unspecified. The default hook fails to decode. + DialectConstantDecodeHook decodeHook = + [](const OpaqueElementsAttr input, ElementsAttr &output) { return true; }; + + /// Registered hook to extract an element from an opaque constant associated + /// with this dialect. If element has been successfully extracted, this + /// method returns that element. If not, it returns an empty attribute. + /// The default hook fails to extract an element. + DialectExtractElementHook extractElementHook = + [](const OpaqueElementsAttr input, ArrayRef index) { + return Attribute(); + }; + /// Registered parsing/printing hooks for types registered to the dialect. DialectTypeParserHook typeParseHook = nullptr; /// Note: The data printed for the provided type must not include any '"' diff --git a/mlir/include/mlir/IR/DialectHooks.h b/mlir/include/mlir/IR/DialectHooks.h new file mode 100644 index 000000000000..dbfb1ab33c70 --- /dev/null +++ b/mlir/include/mlir/IR/DialectHooks.h @@ -0,0 +1,78 @@ +//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines abstraction and registration mechanism for dialect hooks. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_DIALECT_HOOKS_H +#define MLIR_IR_DIALECT_HOOKS_H + +#include "mlir/IR/Dialect.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +using DialectHooksSetter = std::function; + +/// Dialect hooks allow external components to register their functions to +/// be called for specific tasks specialized per dialect, such as decoding +/// of opaque constants. To register concrete dialect hooks, one should +/// define a DialectHooks subclass and use it as a template +/// argument to DialectHooksRegistration. For example, +/// class MyHooks : public DialectHooks {...}; +/// static DialectHooksRegistration hooksReg; +/// The subclass should override DialectHook methods for supported hooks. +class DialectHooks { +public: + // Returns hook to decode opaque constant tensor. + DialectConstantDecodeHook getDecodeHook() { return nullptr; } + // Returns hook to extract an element of an opaque constant tensor. + DialectExtractElementHook getExtractElementHook() { return nullptr; } +}; + +/// Registers a function that will set hooks in the registered dialects +/// based on information coming from DialectHooksRegistration. +void registerDialectHooksSetter(const DialectHooksSetter &function); + +/// DialectHooksRegistration provides a global initialiser that registers +/// a dialect hooks setter routine. +/// Usage: +/// +/// // At namespace scope. +/// static DialectHooksRegistration unused; +template struct DialectHooksRegistration { + DialectHooksRegistration(StringRef dialectName) { + registerDialectHooksSetter([dialectName](MLIRContext *ctx) { + Dialect *dialect = ctx->getRegisteredDialect(dialectName); + if (!dialect) { + llvm::errs() << "error: cannot register hooks for unknown dialect '" + << dialectName << "'\n"; + abort(); + } + // Set hooks. + ConcreteHooks hooks; + if (auto h = hooks.getDecodeHook()) + dialect->decodeHook = h; + if (auto h = hooks.getExtractElementHook()) + dialect->extractElementHook = h; + }); + } +}; + +} // namespace mlir + +#endif diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index fd30808eb557..0dbf40d84183 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -606,6 +606,7 @@ void ModulePrinter::printAttributeOptionalType(Attribute attr, case Attribute::Kind::OpaqueElements: { auto eltsAttr = attr.cast(); os << "opaque<"; + os << '"' << eltsAttr.getDialect()->getNamespace() << "\", "; printType(eltsAttr.getType()); os << ", " << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << '"' << '>'; break; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 8e8604b751d7..ef136eab5b41 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -138,6 +138,7 @@ struct DenseElementsAttributeStorage : public ElementsAttributeStorage { /// An attribute representing a reference to a tensor constant with opaque /// content. struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage { + Dialect *dialect; StringRef bytes; }; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index e35e7771d6c6..5b2b02c1c556 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Attributes.h" #include "AttributeDetail.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Types.h" @@ -341,6 +342,16 @@ StringRef OpaqueElementsAttr::getValue() const { return static_cast(attr)->bytes; } +Dialect *OpaqueElementsAttr::getDialect() const { + return static_cast(attr)->dialect; +} + +bool OpaqueElementsAttr::decode(ElementsAttr &result) { + if (auto *d = getDialect()) + return d->decodeHook(*this, result); + return true; +} + /// SparseElementsAttr DenseIntElementsAttr SparseElementsAttr::getIndices() const { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 3adf5f27bda1..1b197c36760b 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -183,9 +183,10 @@ ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type, return SparseElementsAttr::get(type, indices, values); } -ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type, +ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, + VectorOrTensorType type, StringRef bytes) { - return OpaqueElementsAttr::get(type, bytes); + return OpaqueElementsAttr::get(dialect, type, bytes); } Attribute Builder::getZeroAttr(Type type) { diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index f6a163b18b3b..249c9d84c1f7 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -16,6 +16,7 @@ // ============================================================================= #include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectHooks.h" #include "mlir/IR/MLIRContext.h" #include "llvm/Support/ManagedStatic.h" using namespace mlir; @@ -28,6 +29,10 @@ static llvm::ManagedStatic> static llvm::ManagedStatic> constantFoldHookRegistry; +// Registry for functions that set dialect hooks. +static llvm::ManagedStatic> + dialectHooksRegistry; + /// Registers a specific dialect creation function with the system, typically /// used through the DialectRegistration template. void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) { @@ -44,6 +49,16 @@ void mlir::registerConstantFoldHook(const ConstantFoldHookAllocator &function) { constantFoldHookRegistry->push_back(function); } +/// Registers a function to set specific hooks for a specific dialect, typically +/// used through the DialectHooksRegistreation template. +void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) { + assert( + function && + "Attempting to register an empty dialect hooks initialization function"); + + dialectHooksRegistry->push_back(function); +} + /// Registers all dialects and their const folding hooks with the specified /// MLIRContext. void mlir::registerAllDialects(MLIRContext *context) { @@ -51,6 +66,9 @@ void mlir::registerAllDialects(MLIRContext *context) { fn(context); for (const auto &fn : *constantFoldHookRegistry) fn(context); + for (const auto &fn : *dialectHooksRegistry) { + fn(context); + } } Dialect::Dialect(StringRef namePrefix, MLIRContext *context) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 7c0100d70b28..c33174ba8d34 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -207,23 +207,26 @@ struct DenseElementsAttrInfo : DenseMapInfo { }; struct OpaqueElementsAttrInfo : DenseMapInfo { - using KeyTy = std::pair; + // Opaque element attributes are uniqued based on their dialect, type and + // value. + using KeyTy = std::tuple; using DenseMapInfo::isEqual; static unsigned getHashValue(OpaqueElementsAttributeStorage *key) { - return getHashValue(KeyTy(key->type, key->bytes)); + return getHashValue(KeyTy(key->dialect, key->type, key->bytes)); } static unsigned getHashValue(KeyTy key) { - return hash_combine( - key.first, hash_combine_range(key.second.begin(), key.second.end())); + auto bytes = std::get<2>(key); + return hash_combine(std::get<0>(key), std::get<1>(key), + hash_combine_range(bytes.begin(), bytes.end())); } static bool isEqual(const KeyTy &lhs, const OpaqueElementsAttributeStorage *rhs) { if (rhs == getEmptyKey() || rhs == getTombstoneKey()) return false; - return lhs == std::make_pair(rhs->type, rhs->bytes); + return lhs == std::make_tuple(rhs->dialect, rhs->type, rhs->bytes); } }; @@ -1139,7 +1142,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, return get(type, data); } -OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type, +OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, + VectorOrTensorType type, StringRef bytes) { assert(TensorType::isValidElementType(type.getElementType()) && "Input element type should be a valid tensor element type"); @@ -1147,7 +1151,7 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type, auto &impl = type.getContext()->getImpl(); // Look to see if this constant is already defined. - OpaqueElementsAttrInfo::KeyTy key({type, bytes}); + OpaqueElementsAttrInfo::KeyTy key(dialect, type, bytes); auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key); // If we already have it, return that value. @@ -1156,9 +1160,13 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type, // Otherwise, allocate a new one, unique it and return it. auto *result = impl.allocator.Allocate(); + + // TODO: Provide a way to avoid copying content of large opaque tensors + // This will likely require a new reference attribute kind. bytes = bytes.copy(impl.allocator); new (result) OpaqueElementsAttributeStorage{ {{Attribute::Kind::OpaqueElements, /*isOrContainsFunction=*/false}, type}, + dialect, bytes}; return *existing.first = result; } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 4f2f4a9b528d..ccb618f13c60 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1047,9 +1047,26 @@ Attribute Parser::parseAttribute(Type type) { consumeToken(Token::kw_opaque); if (parseToken(Token::less, "expected '<' after 'opaque'")) return nullptr; + + if (getToken().getKind() != Token::string) + return (emitError("expected dialect namespace"), nullptr); + auto name = getToken().getStringValue(); + auto *dialect = builder.getContext()->getRegisteredDialect(name); + // TODO(shpeisman): Allow for having an unknown dialect on an opaque + // attribute. Otherwise, it can't be roundtripped without having the dialect + // registered. + if (!dialect) + return (emitError("no registered dialect with namespace '" + name + "'"), + nullptr); + + consumeToken(Token::string); + if (parseToken(Token::comma, "expected ','")) + return nullptr; + auto type = parseVectorOrTensorType(); if (!type) return nullptr; + if (getToken().getKind() != Token::string) return (emitError("opaque string should start with '0x'"), nullptr); auto val = getToken().getStringValue(); @@ -1063,7 +1080,7 @@ Attribute Parser::parseAttribute(Type type) { consumeToken(Token::string); if (parseToken(Token::greater, "expected '>'")) return nullptr; - return builder.getOpaqueElementsAttr(type, llvm::fromHex(val)); + return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val)); } case Token::kw_splat: { consumeToken(Token::kw_splat); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 249033885db1..cd48a940d1e3 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -681,14 +681,28 @@ func @elementsattr_toolarge2() -> () { func @elementsattr_malformed_opaque() -> () { ^bb0: - "foo"(){bar: opaque, "0xQZz123">} : () -> () // expected-error {{opaque string only contains hex digits}} + "foo"(){bar: opaque, "0xQZz123">} : () -> () // expected-error {{expected dialect namespace}} } // ----- func @elementsattr_malformed_opaque1() -> () { ^bb0: - "foo"(){bar: opaque, "00abc">} : () -> () // expected-error {{opaque string should start with '0x'}} + "foo"(){bar: opaque<"", tensor<1xi8>, "0xQZz123">} : () -> () // expected-error {{opaque string only contains hex digits}} +} + +// ----- + +func @elementsattr_malformed_opaque2() -> () { +^bb0: + "foo"(){bar: opaque<"", tensor<1xi8>, "00abc">} : () -> () // expected-error {{opaque string should start with '0x'}} +} + +// ----- + +func @elementsattr_malformed_opaque3() -> () { +^bb0: + "foo"(){bar: opaque<"t", tensor<1xi8>, "0xabc">} : () -> () // expected-error {{no registered dialect with namespace 't'}} } // ----- @@ -783,7 +797,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined t func @complex_loops() { for %i1 = 1 to 100 { // expected-error @+1 {{expected '"' in string literal}} - "opaqueIntTensor"(){bar: opaque, "0x686]>} : () -> () + "opaqueIntTensor"(){bar: opaque<"", tensor<2x1x4xi32>, "0x686]>} : () -> () // ----- diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 2fc9a15687ac..e3a48c55a1c8 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -588,15 +588,15 @@ func @splattensorattr() -> () { // CHECK-LABEL: func @opaquetensorattr func @opaquetensorattr() -> () { ^bb0: -// CHECK: "opaqueIntTensor"() {bar: opaque, "0x68656C6C6F">} : () -> () - "opaqueIntTensor"(){bar: opaque, "0x68656C6C6F">} : () -> () -// CHECK: "opaqueFloatTensor"() {bar: opaque, "0x68656C6C6F">} : () -> () - "opaqueFloatTensor"(){bar: opaque, "0x68656C6C6F">} : () -> () +// CHECK: "opaqueIntTensor"() {bar: opaque<"tf", tensor<2x1x4xi32>, "0x68656C6C6F">} : () -> () + "opaqueIntTensor"(){bar: opaque<"tf", tensor<2x1x4xi32>, "0x68656C6C6F">} : () -> () +// CHECK: "opaqueFloatTensor"() {bar: opaque<"tf", tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> () + "opaqueFloatTensor"(){bar: opaque<"tf", tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> () -// CHECK: "opaqueStringTensor"() {bar: opaque>, "0x68656C6C6F">} : () -> () - "opaqueStringTensor"(){bar: opaque>, "0x68656C6C6F">} : () -> () -// CHECK: "opaqueResourceTensor"() {bar: opaque>, "0x68656C6C6F">} : () -> () - "opaqueResourceTensor"(){bar: opaque>, "0x68656C6C6F">} : () -> () +// CHECK: "opaqueStringTensor"() {bar: opaque<"tf", tensor<2x1x4x!tf<"string">>, "0x68656C6C6F">} : () -> () + "opaqueStringTensor"(){bar: opaque<"tf", tensor<2x1x4x!tf<"string">>, "0x68656C6C6F">} : () -> () +// CHECK: "opaqueResourceTensor"() {bar: opaque<"tf", tensor<2x1x4x!tf<"resource">>, "0x68656C6C6F">} : () -> () + "opaqueResourceTensor"(){bar: opaque<"tf", tensor<2x1x4x!tf<"resource">>, "0x68656C6C6F">} : () -> () return }