Add dialect-specific decoding for opaque constants.

Associates opaque constants with a particular dialect. Adds general mechanism to register dialect-specific hooks defined in external components. Adds hooks to decode opaque tensor constant and extract an element of an opaque tensor constant.

This CL does not change the existing mechanism for registering constant folding hook yet. One thing at a time.

PiperOrigin-RevId: 233544757
This commit is contained in:
Tatiana Shpeisman 2019-02-11 22:51:34 -08:00 committed by jpienaar
parent 4b88e7a245
commit 2e6cd60d3b
13 changed files with 207 additions and 24 deletions

View File

@ -24,6 +24,7 @@
namespace mlir {
class AffineMap;
class Dialect;
class Function;
class FunctionAttr;
class FunctionType;
@ -425,17 +426,28 @@ public:
/// An attribute represents a reference to a tensor constant with opaque
/// content. This respresentation is for tensor constants which the compiler
/// doesn't need to interpret.
/// may not need to interpret. This attribute is always associated with
/// a particular dialect, which provides a method to convert tensor
/// representation to a non-opaque format.
class OpaqueElementsAttr : public ElementsAttr {
public:
using ElementsAttr::ElementsAttr;
using ImplType = detail::OpaqueElementsAttributeStorage;
using ValueType = StringRef;
static OpaqueElementsAttr get(VectorOrTensorType type, StringRef bytes);
static OpaqueElementsAttr get(Dialect *dialect, VectorOrTensorType type,
StringRef bytes);
StringRef getValue() const;
/// Decodes the attribute value using dialect-specific decoding hook.
/// Returns false if decoding is successful. If not, returns true and leaves
/// 'result' argument unspecified.
bool decode(ElementsAttr &result);
/// Returns dialect associated with this opaque constant.
Dialect *getDialect() const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::OpaqueElements; }
};

View File

@ -113,7 +113,8 @@ public:
ElementsAttr getSparseElementsAttr(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values);
ElementsAttr getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes);
ElementsAttr getOpaqueElementsAttr(Dialect *dialect, VectorOrTensorType type,
StringRef bytes);
// Returns a 0-valued attribute of the given `type`. This function only
// supports boolean, integer, and 32-/64-bit float types, and vector or ranked
// tensor of them. Returns null attribute otherwise.

View File

@ -29,8 +29,12 @@ class AffineMap;
class IntegerSet;
class Type;
using DialectConstantDecodeHook =
std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>;
using DialectConstantFoldHook = std::function<bool(
const Instruction *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
using DialectExtractElementHook =
std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
using DialectTypeParserHook =
std::function<Type(StringRef, Location, MLIRContext *)>;
using DialectTypePrinterHook = std::function<void(Type, raw_ostream &)>;
@ -59,6 +63,23 @@ public:
[](const Instruction *op, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) { return true; };
/// Registered hook to decode opaque constants associated with this
/// dialect. The hook function attempts to decode an opaque constant tensor
/// into a tensor with non-opaque content. If decoding is successful, this
/// method returns false and sets 'output' attribute. If not, it returns true
/// and leaves 'output' unspecified. The default hook fails to decode.
DialectConstantDecodeHook decodeHook =
[](const OpaqueElementsAttr input, ElementsAttr &output) { return true; };
/// Registered hook to extract an element from an opaque constant associated
/// with this dialect. If element has been successfully extracted, this
/// method returns that element. If not, it returns an empty attribute.
/// The default hook fails to extract an element.
DialectExtractElementHook extractElementHook =
[](const OpaqueElementsAttr input, ArrayRef<uint64_t> index) {
return Attribute();
};
/// Registered parsing/printing hooks for types registered to the dialect.
DialectTypeParserHook typeParseHook = nullptr;
/// Note: The data printed for the provided type must not include any '"'

View File

@ -0,0 +1,78 @@
//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines abstraction and registration mechanism for dialect hooks.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_DIALECT_HOOKS_H
#define MLIR_IR_DIALECT_HOOKS_H
#include "mlir/IR/Dialect.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
using DialectHooksSetter = std::function<void(MLIRContext *)>;
/// Dialect hooks allow external components to register their functions to
/// be called for specific tasks specialized per dialect, such as decoding
/// of opaque constants. To register concrete dialect hooks, one should
/// define a DialectHooks subclass and use it as a template
/// argument to DialectHooksRegistration. For example,
/// class MyHooks : public DialectHooks {...};
/// static DialectHooksRegistration<MyHooks, MyDialect> hooksReg;
/// The subclass should override DialectHook methods for supported hooks.
class DialectHooks {
public:
// Returns hook to decode opaque constant tensor.
DialectConstantDecodeHook getDecodeHook() { return nullptr; }
// Returns hook to extract an element of an opaque constant tensor.
DialectExtractElementHook getExtractElementHook() { return nullptr; }
};
/// Registers a function that will set hooks in the registered dialects
/// based on information coming from DialectHooksRegistration.
void registerDialectHooksSetter(const DialectHooksSetter &function);
/// DialectHooksRegistration provides a global initialiser that registers
/// a dialect hooks setter routine.
/// Usage:
///
/// // At namespace scope.
/// static DialectHooksRegistration<MyHooks, MyDialect> unused;
template <typename ConcreteHooks> struct DialectHooksRegistration {
DialectHooksRegistration(StringRef dialectName) {
registerDialectHooksSetter([dialectName](MLIRContext *ctx) {
Dialect *dialect = ctx->getRegisteredDialect(dialectName);
if (!dialect) {
llvm::errs() << "error: cannot register hooks for unknown dialect '"
<< dialectName << "'\n";
abort();
}
// Set hooks.
ConcreteHooks hooks;
if (auto h = hooks.getDecodeHook())
dialect->decodeHook = h;
if (auto h = hooks.getExtractElementHook())
dialect->extractElementHook = h;
});
}
};
} // namespace mlir
#endif

View File

@ -606,6 +606,7 @@ void ModulePrinter::printAttributeOptionalType(Attribute attr,
case Attribute::Kind::OpaqueElements: {
auto eltsAttr = attr.cast<OpaqueElementsAttr>();
os << "opaque<";
os << '"' << eltsAttr.getDialect()->getNamespace() << "\", ";
printType(eltsAttr.getType());
os << ", " << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << '"' << '>';
break;

View File

@ -138,6 +138,7 @@ struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
/// An attribute representing a reference to a tensor constant with opaque
/// content.
struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
Dialect *dialect;
StringRef bytes;
};

View File

@ -18,6 +18,7 @@
#include "mlir/IR/Attributes.h"
#include "AttributeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Types.h"
@ -341,6 +342,16 @@ StringRef OpaqueElementsAttr::getValue() const {
return static_cast<ImplType *>(attr)->bytes;
}
Dialect *OpaqueElementsAttr::getDialect() const {
return static_cast<ImplType *>(attr)->dialect;
}
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
if (auto *d = getDialect())
return d->decodeHook(*this, result);
return true;
}
/// SparseElementsAttr
DenseIntElementsAttr SparseElementsAttr::getIndices() const {

View File

@ -183,9 +183,10 @@ ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
return SparseElementsAttr::get(type, indices, values);
}
ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type,
ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect,
VectorOrTensorType type,
StringRef bytes) {
return OpaqueElementsAttr::get(type, bytes);
return OpaqueElementsAttr::get(dialect, type, bytes);
}
Attribute Builder::getZeroAttr(Type type) {

View File

@ -16,6 +16,7 @@
// =============================================================================
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectHooks.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/ManagedStatic.h"
using namespace mlir;
@ -28,6 +29,10 @@ static llvm::ManagedStatic<SmallVector<DialectAllocatorFunction, 8>>
static llvm::ManagedStatic<SmallVector<ConstantFoldHookAllocator, 8>>
constantFoldHookRegistry;
// Registry for functions that set dialect hooks.
static llvm::ManagedStatic<SmallVector<DialectHooksSetter, 8>>
dialectHooksRegistry;
/// Registers a specific dialect creation function with the system, typically
/// used through the DialectRegistration template.
void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
@ -44,6 +49,16 @@ void mlir::registerConstantFoldHook(const ConstantFoldHookAllocator &function) {
constantFoldHookRegistry->push_back(function);
}
/// Registers a function to set specific hooks for a specific dialect, typically
/// used through the DialectHooksRegistreation template.
void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
assert(
function &&
"Attempting to register an empty dialect hooks initialization function");
dialectHooksRegistry->push_back(function);
}
/// Registers all dialects and their const folding hooks with the specified
/// MLIRContext.
void mlir::registerAllDialects(MLIRContext *context) {
@ -51,6 +66,9 @@ void mlir::registerAllDialects(MLIRContext *context) {
fn(context);
for (const auto &fn : *constantFoldHookRegistry)
fn(context);
for (const auto &fn : *dialectHooksRegistry) {
fn(context);
}
}
Dialect::Dialect(StringRef namePrefix, MLIRContext *context)

View File

@ -207,23 +207,26 @@ struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
};
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
using KeyTy = std::pair<VectorOrTensorType, StringRef>;
// Opaque element attributes are uniqued based on their dialect, type and
// value.
using KeyTy = std::tuple<Dialect *, VectorOrTensorType, StringRef>;
using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
static unsigned getHashValue(OpaqueElementsAttributeStorage *key) {
return getHashValue(KeyTy(key->type, key->bytes));
return getHashValue(KeyTy(key->dialect, key->type, key->bytes));
}
static unsigned getHashValue(KeyTy key) {
return hash_combine(
key.first, hash_combine_range(key.second.begin(), key.second.end()));
auto bytes = std::get<2>(key);
return hash_combine(std::get<0>(key), std::get<1>(key),
hash_combine_range(bytes.begin(), bytes.end()));
}
static bool isEqual(const KeyTy &lhs,
const OpaqueElementsAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == std::make_pair(rhs->type, rhs->bytes);
return lhs == std::make_tuple(rhs->dialect, rhs->type, rhs->bytes);
}
};
@ -1139,7 +1142,8 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
return get(type, data);
}
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect,
VectorOrTensorType type,
StringRef bytes) {
assert(TensorType::isValidElementType(type.getElementType()) &&
"Input element type should be a valid tensor element type");
@ -1147,7 +1151,7 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
auto &impl = type.getContext()->getImpl();
// Look to see if this constant is already defined.
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
OpaqueElementsAttrInfo::KeyTy key(dialect, type, bytes);
auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key);
// If we already have it, return that value.
@ -1156,9 +1160,13 @@ OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
// Otherwise, allocate a new one, unique it and return it.
auto *result = impl.allocator.Allocate<OpaqueElementsAttributeStorage>();
// TODO: Provide a way to avoid copying content of large opaque tensors
// This will likely require a new reference attribute kind.
bytes = bytes.copy(impl.allocator);
new (result) OpaqueElementsAttributeStorage{
{{Attribute::Kind::OpaqueElements, /*isOrContainsFunction=*/false}, type},
dialect,
bytes};
return *existing.first = result;
}

View File

@ -1047,9 +1047,26 @@ Attribute Parser::parseAttribute(Type type) {
consumeToken(Token::kw_opaque);
if (parseToken(Token::less, "expected '<' after 'opaque'"))
return nullptr;
if (getToken().getKind() != Token::string)
return (emitError("expected dialect namespace"), nullptr);
auto name = getToken().getStringValue();
auto *dialect = builder.getContext()->getRegisteredDialect(name);
// TODO(shpeisman): Allow for having an unknown dialect on an opaque
// attribute. Otherwise, it can't be roundtripped without having the dialect
// registered.
if (!dialect)
return (emitError("no registered dialect with namespace '" + name + "'"),
nullptr);
consumeToken(Token::string);
if (parseToken(Token::comma, "expected ','"))
return nullptr;
auto type = parseVectorOrTensorType();
if (!type)
return nullptr;
if (getToken().getKind() != Token::string)
return (emitError("opaque string should start with '0x'"), nullptr);
auto val = getToken().getStringValue();
@ -1063,7 +1080,7 @@ Attribute Parser::parseAttribute(Type type) {
consumeToken(Token::string);
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return builder.getOpaqueElementsAttr(type, llvm::fromHex(val));
return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val));
}
case Token::kw_splat: {
consumeToken(Token::kw_splat);

View File

@ -681,14 +681,28 @@ func @elementsattr_toolarge2() -> () {
func @elementsattr_malformed_opaque() -> () {
^bb0:
"foo"(){bar: opaque<tensor<1xi8>, "0xQZz123">} : () -> () // expected-error {{opaque string only contains hex digits}}
"foo"(){bar: opaque<tensor<1xi8>, "0xQZz123">} : () -> () // expected-error {{expected dialect namespace}}
}
// -----
func @elementsattr_malformed_opaque1() -> () {
^bb0:
"foo"(){bar: opaque<tensor<1xi8>, "00abc">} : () -> () // expected-error {{opaque string should start with '0x'}}
"foo"(){bar: opaque<"", tensor<1xi8>, "0xQZz123">} : () -> () // expected-error {{opaque string only contains hex digits}}
}
// -----
func @elementsattr_malformed_opaque2() -> () {
^bb0:
"foo"(){bar: opaque<"", tensor<1xi8>, "00abc">} : () -> () // expected-error {{opaque string should start with '0x'}}
}
// -----
func @elementsattr_malformed_opaque3() -> () {
^bb0:
"foo"(){bar: opaque<"t", tensor<1xi8>, "0xabc">} : () -> () // expected-error {{no registered dialect with namespace 't'}}
}
// -----
@ -783,7 +797,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined t
func @complex_loops() {
for %i1 = 1 to 100 {
// expected-error @+1 {{expected '"' in string literal}}
"opaqueIntTensor"(){bar: opaque<tensor<2x1x4xi32>, "0x686]>} : () -> ()
"opaqueIntTensor"(){bar: opaque<"", tensor<2x1x4xi32>, "0x686]>} : () -> ()
// -----

View File

@ -588,15 +588,15 @@ func @splattensorattr() -> () {
// CHECK-LABEL: func @opaquetensorattr
func @opaquetensorattr() -> () {
^bb0:
// CHECK: "opaqueIntTensor"() {bar: opaque<tensor<2x1x4xi32>, "0x68656C6C6F">} : () -> ()
"opaqueIntTensor"(){bar: opaque<tensor<2x1x4xi32>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueFloatTensor"() {bar: opaque<tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
"opaqueFloatTensor"(){bar: opaque<tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueIntTensor"() {bar: opaque<"tf", tensor<2x1x4xi32>, "0x68656C6C6F">} : () -> ()
"opaqueIntTensor"(){bar: opaque<"tf", tensor<2x1x4xi32>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueFloatTensor"() {bar: opaque<"tf", tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
"opaqueFloatTensor"(){bar: opaque<"tf", tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueStringTensor"() {bar: opaque<tensor<2x1x4x!tf<"string">>, "0x68656C6C6F">} : () -> ()
"opaqueStringTensor"(){bar: opaque<tensor<2x1x4x!tf<"string">>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueResourceTensor"() {bar: opaque<tensor<2x1x4x!tf<"resource">>, "0x68656C6C6F">} : () -> ()
"opaqueResourceTensor"(){bar: opaque<tensor<2x1x4x!tf<"resource">>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueStringTensor"() {bar: opaque<"tf", tensor<2x1x4x!tf<"string">>, "0x68656C6C6F">} : () -> ()
"opaqueStringTensor"(){bar: opaque<"tf", tensor<2x1x4x!tf<"string">>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueResourceTensor"() {bar: opaque<"tf", tensor<2x1x4x!tf<"resource">>, "0x68656C6C6F">} : () -> ()
"opaqueResourceTensor"(){bar: opaque<"tf", tensor<2x1x4x!tf<"resource">>, "0x68656C6C6F">} : () -> ()
return
}