forked from OSchip/llvm-project
732 lines
27 KiB
C++
732 lines
27 KiB
C++
//===- AttributeDetail.h - MLIR Attribute details Class ---------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This holds implementation details of Attribute.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef ATTRIBUTEDETAIL_H_
|
|
#define ATTRIBUTEDETAIL_H_
|
|
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/Identifier.h"
|
|
#include "mlir/IR/IntegerSet.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/StandardTypes.h"
|
|
#include "mlir/Support/StorageUniquer.h"
|
|
#include "llvm/ADT/APFloat.h"
|
|
#include "llvm/ADT/PointerIntPair.h"
|
|
#include "llvm/Support/TrailingObjects.h"
|
|
|
|
namespace mlir {
|
|
namespace detail {
|
|
// An attribute representing a reference to an affine map.
|
|
struct AffineMapAttributeStorage : public AttributeStorage {
|
|
using KeyTy = AffineMap;
|
|
|
|
AffineMapAttributeStorage(AffineMap value)
|
|
: AttributeStorage(IndexType::get(value.getContext())), value(value) {}
|
|
|
|
/// Key equality function.
|
|
bool operator==(const KeyTy &key) const { return key == value; }
|
|
|
|
/// Construct a new storage instance.
|
|
static AffineMapAttributeStorage *
|
|
construct(AttributeStorageAllocator &allocator, KeyTy key) {
|
|
return new (allocator.allocate<AffineMapAttributeStorage>())
|
|
AffineMapAttributeStorage(key);
|
|
}
|
|
|
|
AffineMap value;
|
|
};
|
|
|
|
/// An attribute representing an array of other attributes.
|
|
struct ArrayAttributeStorage : public AttributeStorage {
|
|
using KeyTy = ArrayRef<Attribute>;
|
|
|
|
ArrayAttributeStorage(ArrayRef<Attribute> value) : value(value) {}
|
|
|
|
/// Key equality function.
|
|
bool operator==(const KeyTy &key) const { return key == value; }
|
|
|
|
/// Construct a new storage instance.
|
|
static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator,
|
|
const KeyTy &key) {
|
|
return new (allocator.allocate<ArrayAttributeStorage>())
|
|
ArrayAttributeStorage(allocator.copyInto(key));
|
|
}
|
|
|
|
ArrayRef<Attribute> value;
|
|
};
|
|
|
|
/// An attribute representing a boolean value.
|
|
struct BoolAttributeStorage : public AttributeStorage {
|
|
using KeyTy = std::pair<MLIRContext *, bool>;
|
|
|
|
BoolAttributeStorage(Type type, bool value)
|
|
: AttributeStorage(type), value(value) {}
|
|
|
|
/// We only check equality for and hash with the boolean key parameter.
|
|
bool operator==(const KeyTy &key) const { return key.second == value; }
|
|
static unsigned hashKey(const KeyTy &key) {
|
|
return llvm::hash_value(key.second);
|
|
}
|
|
|
|
static BoolAttributeStorage *construct(AttributeStorageAllocator &allocator,
|
|
const KeyTy &key) {
|
|
return new (allocator.allocate<BoolAttributeStorage>())
|
|
BoolAttributeStorage(IntegerType::get(1, key.first), key.second);
|
|
}
|
|
|
|
bool value;
|
|
};
|
|
|
|
/// An attribute representing a dictionary of sorted named attributes.
|
|
struct DictionaryAttributeStorage final
|
|
: public AttributeStorage,
|
|
private llvm::TrailingObjects<DictionaryAttributeStorage,
|
|
NamedAttribute> {
|
|
using KeyTy = ArrayRef<NamedAttribute>;
|
|
|
|
/// Given a list of NamedAttribute's, canonicalize the list (sorting
|
|
/// by name) and return the unique'd result.
|
|
static DictionaryAttributeStorage *get(ArrayRef<NamedAttribute> attrs);
|
|
|
|
/// Key equality function.
|
|
bool operator==(const KeyTy &key) const { return key == getElements(); }
|
|
|
|
/// Construct a new storage instance.
|
|
static DictionaryAttributeStorage *
|
|
construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
|
|
auto size = DictionaryAttributeStorage::totalSizeToAlloc<NamedAttribute>(
|
|
key.size());
|
|
auto rawMem = allocator.allocate(size, alignof(NamedAttribute));
|
|
|
|
// Initialize the storage and trailing attribute list.
|
|
auto result = ::new (rawMem) DictionaryAttributeStorage(key.size());
|
|
std::uninitialized_copy(key.begin(), key.end(),
|
|
result->getTrailingObjects<NamedAttribute>());
|
|
return result;
|
|
}
|
|
|
|
/// Return the elements of this dictionary attribute.
|
|
ArrayRef<NamedAttribute> getElements() const {
|
|
return {getTrailingObjects<NamedAttribute>(), numElements};
|
|
}
|
|
|
|
private:
|
|
friend class llvm::TrailingObjects<DictionaryAttributeStorage,
|
|
NamedAttribute>;
|
|
|
|
// This is used by the llvm::TrailingObjects base class.
|
|
size_t numTrailingObjects(OverloadToken<NamedAttribute>) const {
|
|
return numElements;
|
|
}
|
|
DictionaryAttributeStorage(unsigned numElements) : numElements(numElements) {}
|
|
|
|
/// This is the number of attributes.
|
|
const unsigned numElements;
|
|
};
|
|
|
|
/// An attribute representing a floating point value.
|
|
struct FloatAttributeStorage final
|
|
: public AttributeStorage,
|
|
public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
|
|
using KeyTy = std::pair<Type, APFloat>;
|
|
|
|
FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type,
|
|
size_t numObjects)
|
|
: AttributeStorage(type), semantics(semantics), numObjects(numObjects) {}
|
|
|
|
/// Key equality and hash functions.
|
|
bool operator==(const KeyTy &key) const {
|
|
return key.first == getType() && key.second.bitwiseIsEqual(getValue());
|
|
}
|
|
static unsigned hashKey(const KeyTy &key) {
|
|
return llvm::hash_combine(key.first, llvm::hash_value(key.second));
|
|
}
|
|
|
|
/// Construct a key with a type and double.
|
|
static KeyTy getKey(Type type, double value) {
|
|
// Treat BF16 as double because it is not supported in LLVM's APFloat.
|
|
// TODO(b/121118307): add BF16 support to APFloat?
|
|
if (type.isBF16() || type.isF64())
|
|
return KeyTy(type, APFloat(value));
|
|
|
|
// This handles, e.g., F16 because there is no APFloat constructor for it.
|
|
bool unused;
|
|
APFloat val(value);
|
|
val.convert(type.cast<FloatType>().getFloatSemantics(),
|
|
APFloat::rmNearestTiesToEven, &unused);
|
|
return KeyTy(type, val);
|
|
}
|
|
|
|
/// Construct a new storage instance.
|
|
static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator,
|
|
const KeyTy &key) {
|
|
const auto &apint = key.second.bitcastToAPInt();
|
|
|
|
// Here one word's bitwidth equals to that of uint64_t.
|
|
auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
|
|
|
|
auto byteSize =
|
|
FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
|
|
auto rawMem = allocator.allocate(byteSize, alignof(FloatAttributeStorage));
|
|
auto result = ::new (rawMem) FloatAttributeStorage(
|
|
key.second.getSemantics(), key.first, elements.size());
|
|
std::uninitialized_copy(elements.begin(), elements.end(),
|
|
result->getTrailingObjects<uint64_t>());
|
|
return result;
|
|
}
|
|
|
|
/// Returns an APFloat representing the stored value.
|
|
APFloat getValue() const {
|
|
auto val = APInt(APFloat::getSizeInBits(semantics),
|
|
{getTrailingObjects<uint64_t>(), numObjects});
|
|
return APFloat(semantics, val);
|
|
}
|
|
|
|
const llvm::fltSemantics &semantics;
|
|
size_t numObjects;
|
|
};
|
|
|
|
/// An attribute representing an integral value.
|
|
struct IntegerAttributeStorage final
|
|
: public AttributeStorage,
|
|
public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> {
|
|
using KeyTy = std::pair<Type, APInt>;
|
|
|
|
IntegerAttributeStorage(Type type, size_t numObjects)
|
|
: AttributeStorage(type), numObjects(numObjects) {
|
|
assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
|
|
}
|
|
|
|
/// Key equality and hash functions.
|
|
bool operator==(const KeyTy &key) const {
|
|
return key == KeyTy(getType(), getValue());
|
|
}
|
|
static unsigned hashKey(const KeyTy &key) {
|
|
return llvm::hash_combine(key.first, llvm::hash_value(key.second));
|
|
}
|
|
|
|
/// Construct a new storage instance.
|
|
static IntegerAttributeStorage *
|
|
construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
|
|
Type type;
|
|
APInt value;
|
|
std::tie(type, value) = key;
|
|
|
|
auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords());
|
|
auto size =
|
|
IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
|
|
auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage));
|
|
auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size());
|
|
std::uninitialized_copy(elements.begin(), elements.end(),
|
|
result->getTrailingObjects<uint64_t>());
|
|
return result;
|
|
}
|
|
|
|
/// Returns an APInt representing the stored value.
|
|
APInt getValue() const {
|
|
if (getType().isIndex())
|
|
return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
|
|
return APInt(getType().getIntOrFloatBitWidth(),
|
|
{getTrailingObjects<uint64_t>(), numObjects});
|
|
}
|
|
|
|
size_t numObjects;
|
|
};
|
|
|
|
// An attribute representing a reference to an integer set.
|
|
struct IntegerSetAttributeStorage : public AttributeStorage {
|
|
using KeyTy = IntegerSet;
|
|
|
|
IntegerSetAttributeStorage(IntegerSet value) : value(value) {}
|
|
|
|
/// Key equality function.
|
|
bool operator==(const KeyTy &key) const { return key == value; }
|
|
|
|
/// Construct a new storage instance.
|
|
static IntegerSetAttributeStorage *
|
|
construct(AttributeStorageAllocator &allocator, KeyTy key) {
|
|
return new (allocator.allocate<IntegerSetAttributeStorage>())
|
|
IntegerSetAttributeStorage(key);
|
|
}
|
|
|
|
IntegerSet value;
|
|
};
|
|
|
|
/// Opaque Attribute Storage and Uniquing.
|
|
struct OpaqueAttributeStorage : public AttributeStorage {
|
|
OpaqueAttributeStorage(Identifier dialectNamespace, StringRef attrData,
|
|
Type type)
|
|
: AttributeStorage(type), dialectNamespace(dialectNamespace),
|
|
attrData(attrData) {}
|
|
|
|
/// The hash key used for uniquing.
|
|
using KeyTy = std::tuple<Identifier, StringRef, Type>;
|
|
bool operator==(const KeyTy &key) const {
|
|
return key == KeyTy(dialectNamespace, attrData, getType());
|
|
}
|
|
|
|
static OpaqueAttributeStorage *construct(AttributeStorageAllocator &allocator,
|
|
const KeyTy &key) {
|
|
return new (allocator.allocate<OpaqueAttributeStorage>())
|
|
OpaqueAttributeStorage(std::get<0>(key),
|
|
allocator.copyInto(std::get<1>(key)),
|
|
std::get<2>(key));
|
|
}
|
|
|
|
// The dialect namespace.
|
|
Identifier dialectNamespace;
|
|
|
|
// The parser attribute data for this opaque attribute.
|
|
StringRef attrData;
|
|
};
|
|
|
|
/// An attribute representing a string value.
|
|
struct StringAttributeStorage : public AttributeStorage {
|
|
using KeyTy = std::pair<StringRef, Type>;
|
|
|
|
StringAttributeStorage(StringRef value, Type type)
|
|
: AttributeStorage(type), value(value) {}
|
|
|
|
/// Key equality function.
|
|
bool operator==(const KeyTy &key) const {
|
|
return key == KeyTy(value, getType());
|
|
}
|
|
|
|
/// Construct a new storage instance.
|
|
static StringAttributeStorage *construct(AttributeStorageAllocator &allocator,
|
|
const KeyTy &key) {
|
|
return new (allocator.allocate<StringAttributeStorage>())
|
|
StringAttributeStorage(allocator.copyInto(key.first), key.second);
|
|
}
|
|
|
|
StringRef value;
|
|
};
|
|
|
|
/// An attribute representing a symbol reference.
|
|
struct SymbolRefAttributeStorage final
|
|
: public AttributeStorage,
|
|
public llvm::TrailingObjects<SymbolRefAttributeStorage,
|
|
FlatSymbolRefAttr> {
|
|
using KeyTy = std::pair<StringRef, ArrayRef<FlatSymbolRefAttr>>;
|
|
|
|
SymbolRefAttributeStorage(StringRef value, size_t numNestedRefs)
|
|
: value(value), numNestedRefs(numNestedRefs) {}
|
|
|
|
/// Key equality function.
|
|
bool operator==(const KeyTy &key) const {
|
|
return key == KeyTy(value, getNestedRefs());
|
|
}
|
|
|
|
/// Construct a new storage instance.
|
|
static SymbolRefAttributeStorage *
|
|
construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
|
|
auto size = SymbolRefAttributeStorage::totalSizeToAlloc<FlatSymbolRefAttr>(
|
|
key.second.size());
|
|
auto rawMem = allocator.allocate(size, alignof(SymbolRefAttributeStorage));
|
|
auto result = ::new (rawMem) SymbolRefAttributeStorage(
|
|
allocator.copyInto(key.first), key.second.size());
|
|
std::uninitialized_copy(key.second.begin(), key.second.end(),
|
|
result->getTrailingObjects<FlatSymbolRefAttr>());
|
|
return result;
|
|
}
|
|
|
|
/// Returns the set of nested references.
|
|
ArrayRef<FlatSymbolRefAttr> getNestedRefs() const {
|
|
return {getTrailingObjects<FlatSymbolRefAttr>(), numNestedRefs};
|
|
}
|
|
|
|
StringRef value;
|
|
size_t numNestedRefs;
|
|
};
|
|
|
|
/// An attribute representing a reference to a type.
|
|
struct TypeAttributeStorage : public AttributeStorage {
|
|
using KeyTy = Type;
|
|
|
|
TypeAttributeStorage(Type value) : value(value) {}
|
|
|
|
/// Key equality function.
|
|
bool operator==(const KeyTy &key) const { return key == value; }
|
|
|
|
/// Construct a new storage instance.
|
|
static TypeAttributeStorage *construct(AttributeStorageAllocator &allocator,
|
|
KeyTy key) {
|
|
return new (allocator.allocate<TypeAttributeStorage>())
|
|
TypeAttributeStorage(key);
|
|
}
|
|
|
|
Type value;
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Elements Attributes
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Return the bit width which DenseElementsAttr should use for this type.
|
|
inline size_t getDenseElementBitWidth(Type eltType) {
|
|
// Align the width for complex to 8 to make storage and interpretation easier.
|
|
if (ComplexType comp = eltType.dyn_cast<ComplexType>())
|
|
return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
|
|
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
|
// with double semantics.
|
|
if (eltType.isBF16())
|
|
return 64;
|
|
if (eltType.isIndex())
|
|
return IndexType::kInternalStorageBitWidth;
|
|
return eltType.getIntOrFloatBitWidth();
|
|
}
|
|
|
|
/// An attribute representing a reference to a dense vector or tensor object.
|
|
struct DenseElementsAttributeStorage : public AttributeStorage {
|
|
public:
|
|
DenseElementsAttributeStorage(ShapedType ty, bool isSplat)
|
|
: AttributeStorage(ty), isSplat(isSplat) {}
|
|
|
|
bool isSplat;
|
|
};
|
|
|
|
/// An attribute representing a reference to a dense vector or tensor object.
|
|
struct DenseIntOrFPElementsAttributeStorage
|
|
: public DenseElementsAttributeStorage {
|
|
DenseIntOrFPElementsAttributeStorage(ShapedType ty, ArrayRef<char> data,
|
|
bool isSplat = false)
|
|
: DenseElementsAttributeStorage(ty, isSplat), data(data) {}
|
|
|
|
struct KeyTy {
|
|
KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
|
|
bool isSplat = false)
|
|
: type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
|
|
|
|
/// The type of the dense elements.
|
|
ShapedType type;
|
|
|
|
/// The raw buffer for the data storage.
|
|
ArrayRef<char> data;
|
|
|
|
/// The computed hash code for the storage data.
|
|
llvm::hash_code hashCode;
|
|
|
|
/// A boolean that indicates if this data is a splat or not.
|
|
bool isSplat;
|
|
};
|
|
|
|
/// Compare this storage instance with the provided key.
|
|
bool operator==(const KeyTy &key) const {
|
|
if (key.type != getType())
|
|
return false;
|
|
|
|
// For boolean splats we need to explicitly check that the first bit is the
|
|
// same. Boolean values are packed at the bit level, and even though a splat
|
|
// is detected the rest of the bits in the first byte may differ from the
|
|
// splat value.
|
|
if (key.type.getElementType().isInteger(1)) {
|
|
if (key.isSplat != isSplat)
|
|
return false;
|
|
if (isSplat)
|
|
return (key.data.front() & 1) == data.front();
|
|
}
|
|
|
|
// Otherwise, we can default to just checking the data.
|
|
return key.data == data;
|
|
}
|
|
|
|
/// Construct a key from a shaped type, raw data buffer, and a flag that
|
|
/// signals if the data is already known to be a splat. Callers to this
|
|
/// function are expected to tag preknown splat values when possible, e.g. one
|
|
/// element shapes.
|
|
static KeyTy getKey(ShapedType ty, ArrayRef<char> data, bool isKnownSplat) {
|
|
// Handle an empty storage instance.
|
|
if (data.empty())
|
|
return KeyTy(ty, data, 0);
|
|
|
|
// If the data is already known to be a splat, the key hash value is
|
|
// directly the data buffer.
|
|
if (isKnownSplat)
|
|
return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
|
|
|
|
// Otherwise, we need to check if the data corresponds to a splat or not.
|
|
|
|
// Handle the simple case of only one element.
|
|
size_t numElements = ty.getNumElements();
|
|
assert(numElements != 1 && "splat of 1 element should already be detected");
|
|
|
|
// Handle boolean values directly as they are packed to 1-bit.
|
|
if (ty.getElementType().isInteger(1) == 1)
|
|
return getKeyForBoolData(ty, data, numElements);
|
|
|
|
size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
|
|
// Non 1-bit dense elements are padded to 8-bits.
|
|
size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
|
|
assert(((data.size() / storageSize) == numElements) &&
|
|
"data does not hold expected number of elements");
|
|
|
|
// Create the initial hash value with just the first element.
|
|
auto firstElt = data.take_front(storageSize);
|
|
auto hashVal = llvm::hash_value(firstElt);
|
|
|
|
// Check to see if this storage represents a splat. If it doesn't then
|
|
// combine the hash for the data starting with the first non splat element.
|
|
for (size_t i = storageSize, e = data.size(); i != e; i += storageSize)
|
|
if (memcmp(data.data(), &data[i], storageSize))
|
|
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
|
|
|
|
// Otherwise, this is a splat so just return the hash of the first element.
|
|
return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true);
|
|
}
|
|
|
|
/// Construct a key with a set of boolean data.
|
|
static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data,
|
|
size_t numElements) {
|
|
ArrayRef<char> splatData = data;
|
|
bool splatValue = splatData.front() & 1;
|
|
|
|
// Helper functor to generate a KeyTy for a boolean splat value.
|
|
auto generateSplatKey = [=] {
|
|
return KeyTy(ty, data.take_front(1),
|
|
llvm::hash_value(ArrayRef<char>(splatValue ? 1 : 0)),
|
|
/*isSplat=*/true);
|
|
};
|
|
|
|
// Handle the case where the potential splat value is 1 and the number of
|
|
// elements is non 8-bit aligned.
|
|
size_t numOddElements = numElements % CHAR_BIT;
|
|
if (splatValue && numOddElements != 0) {
|
|
// Check that all bits are set in the last value.
|
|
char lastElt = splatData.back();
|
|
if (lastElt != llvm::maskTrailingOnes<unsigned char>(numOddElements))
|
|
return KeyTy(ty, data, llvm::hash_value(data));
|
|
|
|
// If this is the only element, the data is known to be a splat.
|
|
if (splatData.size() == 1)
|
|
return generateSplatKey();
|
|
splatData = splatData.drop_back();
|
|
}
|
|
|
|
// Check that the data buffer corresponds to a splat of the proper mask.
|
|
char mask = splatValue ? ~0 : 0;
|
|
return llvm::all_of(splatData, [mask](char c) { return c == mask; })
|
|
? generateSplatKey()
|
|
: KeyTy(ty, data, llvm::hash_value(data));
|
|
}
|
|
|
|
/// Hash the key for the storage.
|
|
static llvm::hash_code hashKey(const KeyTy &key) {
|
|
return llvm::hash_combine(key.type, key.hashCode);
|
|
}
|
|
|
|
/// Construct a new storage instance.
|
|
static DenseIntOrFPElementsAttributeStorage *
|
|
construct(AttributeStorageAllocator &allocator, KeyTy key) {
|
|
// If the data buffer is non-empty, we copy it into the allocator with a
|
|
// 64-bit alignment.
|
|
ArrayRef<char> copy, data = key.data;
|
|
if (!data.empty()) {
|
|
char *rawData = reinterpret_cast<char *>(
|
|
allocator.allocate(data.size(), alignof(uint64_t)));
|
|
std::memcpy(rawData, data.data(), data.size());
|
|
|
|
// If this is a boolean splat, make sure only the first bit is used.
|
|
if (key.isSplat && key.type.getElementType().isInteger(1))
|
|
rawData[0] &= 1;
|
|
copy = ArrayRef<char>(rawData, data.size());
|
|
}
|
|
|
|
return new (allocator.allocate<DenseIntOrFPElementsAttributeStorage>())
|
|
DenseIntOrFPElementsAttributeStorage(key.type, copy, key.isSplat);
|
|
}
|
|
|
|
ArrayRef<char> data;
|
|
};
|
|
|
|
/// An attribute representing a reference to a dense vector or tensor object
|
|
/// containing strings.
|
|
struct DenseStringElementsAttributeStorage
|
|
: public DenseElementsAttributeStorage {
|
|
DenseStringElementsAttributeStorage(ShapedType ty, ArrayRef<StringRef> data,
|
|
bool isSplat = false)
|
|
: DenseElementsAttributeStorage(ty, isSplat), data(data) {}
|
|
|
|
struct KeyTy {
|
|
KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode,
|
|
bool isSplat = false)
|
|
: type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
|
|
|
|
/// The type of the dense elements.
|
|
ShapedType type;
|
|
|
|
/// The raw buffer for the data storage.
|
|
ArrayRef<StringRef> data;
|
|
|
|
/// The computed hash code for the storage data.
|
|
llvm::hash_code hashCode;
|
|
|
|
/// A boolean that indicates if this data is a splat or not.
|
|
bool isSplat;
|
|
};
|
|
|
|
/// Compare this storage instance with the provided key.
|
|
bool operator==(const KeyTy &key) const {
|
|
if (key.type != getType())
|
|
return false;
|
|
|
|
// Otherwise, we can default to just checking the data. StringRefs compare
|
|
// by contents.
|
|
return key.data == data;
|
|
}
|
|
|
|
/// Construct a key from a shaped type, StringRef data buffer, and a flag that
|
|
/// signals if the data is already known to be a splat. Callers to this
|
|
/// function are expected to tag preknown splat values when possible, e.g. one
|
|
/// element shapes.
|
|
static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data,
|
|
bool isKnownSplat) {
|
|
// Handle an empty storage instance.
|
|
if (data.empty())
|
|
return KeyTy(ty, data, 0);
|
|
|
|
// If the data is already known to be a splat, the key hash value is
|
|
// directly the data buffer.
|
|
if (isKnownSplat)
|
|
return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat);
|
|
|
|
// Handle the simple case of only one element.
|
|
assert(ty.getNumElements() != 1 &&
|
|
"splat of 1 element should already be detected");
|
|
|
|
// Create the initial hash value with just the first element.
|
|
const auto &firstElt = data.front();
|
|
auto hashVal = llvm::hash_value(firstElt);
|
|
|
|
// Check to see if this storage represents a splat. If it doesn't then
|
|
// combine the hash for the data starting with the first non splat element.
|
|
for (size_t i = 1, e = data.size(); i != e; i++)
|
|
if (!firstElt.equals(data[i]))
|
|
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
|
|
|
|
// Otherwise, this is a splat so just return the hash of the first element.
|
|
return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true);
|
|
}
|
|
|
|
/// Hash the key for the storage.
|
|
static llvm::hash_code hashKey(const KeyTy &key) {
|
|
return llvm::hash_combine(key.type, key.hashCode);
|
|
}
|
|
|
|
/// Construct a new storage instance.
|
|
static DenseStringElementsAttributeStorage *
|
|
construct(AttributeStorageAllocator &allocator, KeyTy key) {
|
|
// If the data buffer is non-empty, we copy it into the allocator with a
|
|
// 64-bit alignment.
|
|
ArrayRef<StringRef> copy, data = key.data;
|
|
if (data.empty()) {
|
|
return new (allocator.allocate<DenseStringElementsAttributeStorage>())
|
|
DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
|
|
}
|
|
|
|
int numEntries = key.isSplat ? 1 : data.size();
|
|
|
|
// Compute the amount data needed to store the ArrayRef and StringRef
|
|
// contents.
|
|
size_t dataSize = sizeof(StringRef) * numEntries;
|
|
for (int i = 0; i < numEntries; i++)
|
|
dataSize += data[i].size();
|
|
|
|
char *rawData = reinterpret_cast<char *>(
|
|
allocator.allocate(dataSize, alignof(uint64_t)));
|
|
|
|
// Setup a mutable array ref of our string refs so that we can update their
|
|
// contents.
|
|
auto mutableCopy = MutableArrayRef<StringRef>(
|
|
reinterpret_cast<StringRef *>(rawData), numEntries);
|
|
auto stringData = rawData + numEntries * sizeof(StringRef);
|
|
|
|
for (int i = 0; i < numEntries; i++) {
|
|
memcpy(stringData, data[i].data(), data[i].size());
|
|
mutableCopy[i] = StringRef(stringData, data[i].size());
|
|
stringData += data[i].size();
|
|
}
|
|
|
|
copy =
|
|
ArrayRef<StringRef>(reinterpret_cast<StringRef *>(rawData), numEntries);
|
|
|
|
return new (allocator.allocate<DenseStringElementsAttributeStorage>())
|
|
DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
|
|
}
|
|
|
|
ArrayRef<StringRef> data;
|
|
};
|
|
|
|
/// An attribute representing a reference to a tensor constant with opaque
|
|
/// content.
|
|
struct OpaqueElementsAttributeStorage : public AttributeStorage {
|
|
using KeyTy = std::tuple<Type, Dialect *, StringRef>;
|
|
|
|
OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes)
|
|
: AttributeStorage(type), dialect(dialect), bytes(bytes) {}
|
|
|
|
/// Key equality and hash functions.
|
|
bool operator==(const KeyTy &key) const {
|
|
return key == std::make_tuple(getType(), dialect, bytes);
|
|
}
|
|
static unsigned hashKey(const KeyTy &key) {
|
|
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
|
|
std::get<2>(key));
|
|
}
|
|
|
|
/// Construct a new storage instance.
|
|
static OpaqueElementsAttributeStorage *
|
|
construct(AttributeStorageAllocator &allocator, KeyTy key) {
|
|
// TODO(b/131468830): Provide a way to avoid copying content of large opaque
|
|
// tensors This will likely require a new reference attribute kind.
|
|
return new (allocator.allocate<OpaqueElementsAttributeStorage>())
|
|
OpaqueElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
|
|
allocator.copyInto(std::get<2>(key)));
|
|
}
|
|
|
|
Dialect *dialect;
|
|
StringRef bytes;
|
|
};
|
|
|
|
/// An attribute representing a reference to a sparse vector or tensor object.
|
|
struct SparseElementsAttributeStorage : public AttributeStorage {
|
|
using KeyTy = std::tuple<Type, DenseIntElementsAttr, DenseElementsAttr>;
|
|
|
|
SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices,
|
|
DenseElementsAttr values)
|
|
: AttributeStorage(type), indices(indices), values(values) {}
|
|
|
|
/// Key equality and hash functions.
|
|
bool operator==(const KeyTy &key) const {
|
|
return key == std::make_tuple(getType(), indices, values);
|
|
}
|
|
static unsigned hashKey(const KeyTy &key) {
|
|
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
|
|
std::get<2>(key));
|
|
}
|
|
|
|
/// Construct a new storage instance.
|
|
static SparseElementsAttributeStorage *
|
|
construct(AttributeStorageAllocator &allocator, KeyTy key) {
|
|
return new (allocator.allocate<SparseElementsAttributeStorage>())
|
|
SparseElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
|
|
std::get<2>(key));
|
|
}
|
|
|
|
DenseIntElementsAttr indices;
|
|
DenseElementsAttr values;
|
|
};
|
|
} // namespace detail
|
|
} // namespace mlir
|
|
|
|
#endif // ATTRIBUTEDETAIL_H_
|