[mlir][StorageUniquer] Properly call the destructor on non-trivially destructible storage instances

This allows for storage instances to store data that isn't uniqued in the context, or contain otherwise non-trivial logic, in the rare situations that they occur. Storage instances with trivial destructors will still have their destructor skipped. A consequence of this is that the storage instance definition must be visible from the place that registers the type.

Differential Revision: https://reviews.llvm.org/D98311
This commit is contained in:
River Riddle 2021-03-11 11:24:43 -08:00
parent dc9c09632f
commit 31bb8efd69
30 changed files with 309 additions and 111 deletions

View File

@ -32,6 +32,12 @@ public:
mlir::Type type) const override;
void printAttribute(mlir::Attribute attr,
mlir::DialectAsmPrinter &p) const override;
private:
// Register the Attributes of this dialect.
void registerAttributes();
// Register the Types of this dialect.
void registerTypes();
};
} // namespace fir

View File

@ -243,3 +243,12 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
os << "<(unknown attribute)>";
}
}
//===----------------------------------------------------------------------===//
// FIROpsDialect
//===----------------------------------------------------------------------===//
void FIROpsDialect::registerAttributes() {
addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
}

View File

@ -19,13 +19,8 @@ using namespace fir;
fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
: mlir::Dialect("fir", ctx, mlir::TypeID::get<FIROpsDialect>()) {
addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
PointerType, RealType, RecordType, ReferenceType, SequenceType,
ShapeType, ShapeShiftType, ShiftType, SliceType, TypeDescType,
fir::VectorType>();
addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
registerTypes();
registerAttributes();
addOperations<
#define GET_OP_LIST
#include "flang/Optimizer/Dialect/FIROps.cpp.inc"

View File

@ -866,3 +866,15 @@ mlir::LogicalResult fir::VectorType::verify(
bool fir::VectorType::isValidElementType(mlir::Type t) {
return isa_real(t) || isa_integer(t);
}
//===----------------------------------------------------------------------===//
// FIROpsDialect
//===----------------------------------------------------------------------===//
void FIROpsDialect::registerTypes() {
addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
PointerType, RealType, RecordType, ReferenceType, SequenceType,
ShapeType, ShapeShiftType, ShiftType, SliceType, TypeDescType,
fir::VectorType>();
}

View File

@ -319,7 +319,9 @@ public:
Once the dialect types have been defined, they must then be registered with a
`Dialect`. This is done via a similar mechanism to
[operations](LangRef.md#operations), with the `addTypes` method.
[operations](LangRef.md#operations), with the `addTypes` method. The one
distinct difference with operations, is that when a type is registered the
definition of its storage class must be visible.
```c++
struct MyDialect : public Dialect {

View File

@ -187,6 +187,9 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
}
```
(An important note here is that when registering a type, the definition of the
storage class must be visible.)
With this we can now use our `StructType` when generating MLIR from Toy. See
examples/toy/Ch7/mlir/MLIRGen.cpp for more details.

View File

@ -76,33 +76,6 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
}
};
//===----------------------------------------------------------------------===//
// ToyDialect
//===----------------------------------------------------------------------===//
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
>();
addInterfaces<ToyInlinerInterface>();
addTypes<StructType>();
}
mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
mlir::Attribute value,
mlir::Type type,
mlir::Location loc) {
if (type.isa<StructType>())
return builder.create<StructConstantOp>(loc, type,
value.cast<mlir::ArrayAttr>());
return builder.create<ConstantOp>(loc, type,
value.cast<mlir::DenseElementsAttr>());
}
//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//
@ -566,3 +539,30 @@ void ToyDialect::printType(mlir::Type type,
#define GET_OP_CLASSES
#include "toy/Ops.cpp.inc"
//===----------------------------------------------------------------------===//
// ToyDialect
//===----------------------------------------------------------------------===//
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
>();
addInterfaces<ToyInlinerInterface>();
addTypes<StructType>();
}
mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
mlir::Attribute value,
mlir::Type type,
mlir::Location loc) {
if (type.isa<StructType>())
return builder.create<StructConstantOp>(loc, type,
value.cast<mlir::ArrayAttr>());
return builder.create<ConstantOp>(loc, type,
value.cast<mlir::DenseElementsAttr>());
}

View File

@ -64,6 +64,9 @@ def PDL_Dialect : Dialect {
let name = "pdl";
let cppNamespace = "::mlir::pdl";
let extraClassDeclaration = [{
void registerTypes();
}];
}
#endif // MLIR_DIALECT_PDL_IR_PDLDIALECT

View File

@ -52,6 +52,9 @@ def SPIRV_Dialect : Dialect {
let hasRegionResultAttrVerify = 1;
let extraClassDeclaration = [{
void registerAttributes();
void registerTypes();
//===------------------------------------------------------------------===//
// Attribute
//===------------------------------------------------------------------===//

View File

@ -22,6 +22,17 @@ def Builtin_Dialect : Dialect {
let name = "";
let cppNamespace = "::mlir";
let extraClassDeclaration = [{
private:
// Register the builtin Attributes.
void registerAttributes();
// Register the builtin Location Attributes.
void registerLocationAttributes();
// Register the builtin Types.
void registerTypes();
public:
}];
}
#endif // BUILTIN_BASE

View File

@ -135,7 +135,13 @@ public:
/// instances of this class type. `id` is the type identifier that will be
/// used to identify this type when creating instances of it via 'get'.
template <typename Storage> void registerParametricStorageType(TypeID id) {
registerParametricStorageTypeImpl(id);
// If the storage is trivially destructible, we don't need a destructor
// function.
if (std::is_trivially_destructible<Storage>::value)
return registerParametricStorageTypeImpl(id, nullptr);
registerParametricStorageTypeImpl(id, [](BaseStorage *storage) {
static_cast<Storage *>(storage)->~Storage();
});
}
/// Utility override when the storage type represents the type id.
template <typename Storage> void registerParametricStorageType() {
@ -244,8 +250,10 @@ private:
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
/// Implementation for registering an instance of a derived type with
/// parametric storage.
void registerParametricStorageTypeImpl(TypeID id);
/// parametric storage. This method takes an optional destructor function that
/// destructs storage instances when necessary.
void registerParametricStorageTypeImpl(
TypeID id, function_ref<void(BaseStorage *)> destructorFn);
/// Implementation for getting an instance of a derived type with default
/// storage.

View File

@ -20,6 +20,12 @@
using namespace mlir;
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
void arm_sve::ArmSVEDialect::initialize() {
addOperations<
#define GET_OP_LIST
@ -31,12 +37,6 @@ void arm_sve::ArmSVEDialect::initialize() {
>();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
//===----------------------------------------------------------------------===//
// ScalableVectorType
//===----------------------------------------------------------------------===//

View File

@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "TypeDetail.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"

View File

@ -25,10 +25,7 @@ void PDLDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
>();
registerTypes();
}
/// Returns true if the given operation is used by a "binding" pdl operation

View File

@ -26,6 +26,13 @@ using namespace mlir::pdl;
// PDLDialect
//===----------------------------------------------------------------------===//
void PDLDialect::registerTypes() {
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
>();
}
static Type parsePDLType(DialectAsmParser &parser) {
StringRef typeTag;
if (parser.parseKeyword(&typeTag))

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
@ -350,3 +351,11 @@ spirv::TargetEnvAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
//===----------------------------------------------------------------------===//
// SPIR-V Dialect
//===----------------------------------------------------------------------===//
void spirv::SPIRVDialect::registerAttributes() {
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
}

View File

@ -115,10 +115,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
//===----------------------------------------------------------------------===//
void SPIRVDialect::initialize() {
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
PointerType, RuntimeArrayType, SampledImageType, StructType>();
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
registerAttributes();
registerTypes();
// Add SPIR-V ops.
addOperations<

View File

@ -1154,3 +1154,12 @@ void MatrixType::getCapabilities(
// Add any capabilities associated with the underlying vectors (i.e., columns)
getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
}
//===----------------------------------------------------------------------===//
// SPIR-V Dialect
//===----------------------------------------------------------------------===//
void SPIRVDialect::registerTypes() {
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
PointerType, RuntimeArrayType, SampledImageType, StructType>();
}

View File

@ -100,36 +100,6 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
return false;
}
//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//
void VectorDialect::initialize() {
addAttributes<CombiningKindAttr>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
>();
}
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<ConstantOp>(loc, type, value);
}
IntegerType vector::getVectorSubscriptType(Builder &builder) {
return builder.getIntegerType(64);
}
ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
ArrayRef<int64_t> values) {
return builder.getI64ArrayAttr(values);
}
//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//
@ -230,6 +200,36 @@ void VectorDialect::printAttribute(Attribute attr,
llvm_unreachable("Unknown attribute type");
}
//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//
void VectorDialect::initialize() {
addAttributes<CombiningKindAttr>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
>();
}
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<ConstantOp>(loc, type, value);
}
IntegerType vector::getVectorSubscriptType(Builder &builder) {
return builder.getIntegerType(64);
}
ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
ArrayRef<int64_t> values) {
return builder.getI64ArrayAttr(values);
}
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

View File

@ -9,6 +9,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "AttributeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
@ -28,6 +29,18 @@ using namespace mlir::detail;
#define GET_ATTRDEF_CLASSES
#include "mlir/IR/BuiltinAttributes.cpp.inc"
//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//
void BuiltinDialect::registerAttributes() {
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
UnitAttr>();
}
//===----------------------------------------------------------------------===//
// DictionaryAttr
//===----------------------------------------------------------------------===//

View File

@ -60,17 +60,9 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
} // end anonymous namespace.
void BuiltinDialect::initialize() {
addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
Float80Type, Float128Type, FunctionType, IndexType, IntegerType,
MemRefType, UnrankedMemRefType, NoneType, OpaqueType,
RankedTensorType, TupleType, UnrankedTensorType, VectorType>();
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
UnitAttr>();
addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
UnknownLoc>();
registerTypes();
registerAttributes();
registerLocationAttributes();
addOperations<
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"

View File

@ -30,6 +30,17 @@ using namespace mlir::detail;
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"
//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//
void BuiltinDialect::registerTypes() {
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
>();
}
//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//
@ -514,7 +525,7 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
if (!BaseMemRefType::isValidElementType(elementType))
return emitError() << "invalid memref element type";
// Negative sizes are not allowed except for `-1` that means dynamic size.
// Negative sizes are not allowed except for `-1` that means dynamic size.
for (int64_t s : shape)
if (s < -1)
return emitError() << "invalid memref size";

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Location.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Identifier.h"
#include "llvm/ADT/SetVector.h"
@ -20,6 +21,17 @@ using namespace mlir::detail;
#define GET_ATTRDEF_CLASSES
#include "mlir/IR/BuiltinLocationAttributes.cpp.inc"
//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//
void BuiltinDialect::registerLocationAttributes() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/IR/BuiltinLocationAttributes.cpp.inc"
>();
}
//===----------------------------------------------------------------------===//
// LocationAttr
//===----------------------------------------------------------------------===//

View File

@ -100,12 +100,23 @@ private:
return storage;
}
/// Destroy all of the storage instances within the given shard.
void destroyShardInstances(Shard &shard) {
if (!destructorFn)
return;
for (HashedStorage &instance : shard.instances)
destructorFn(instance.storage);
}
public:
#if LLVM_ENABLE_THREADS != 0
/// Initialize the storage uniquer with a given number of storage shards to
/// use. The provided shard number is required to be a valid power of 2.
ParametricStorageUniquer(size_t numShards = 8)
: shards(new std::atomic<Shard *>[numShards]), numShards(numShards) {
/// use. The provided shard number is required to be a valid power of 2. The
/// destructor function is used to destroy any allocated storage instances.
ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
size_t numShards = 8)
: shards(new std::atomic<Shard *>[numShards]), numShards(numShards),
destructorFn(destructorFn) {
assert(llvm::isPowerOf2_64(numShards) &&
"the number of shards is required to be a power of 2");
for (size_t i = 0; i < numShards; i++)
@ -113,9 +124,12 @@ public:
}
~ParametricStorageUniquer() {
// Free all of the allocated shards.
for (size_t i = 0; i != numShards; ++i)
if (Shard *shard = shards[i].load())
for (size_t i = 0; i != numShards; ++i) {
if (Shard *shard = shards[i].load()) {
destroyShardInstances(*shard);
delete shard;
}
}
}
/// Get or create an instance of a parametric type.
BaseStorage *
@ -204,10 +218,17 @@ private:
/// The number of available shards.
size_t numShards;
/// Function to used to destruct any allocated storage instances.
function_ref<void(BaseStorage *)> destructorFn;
#else
/// If multi-threading is disabled, ignore the shard parameter as we will
/// always use one shard.
ParametricStorageUniquer(size_t numShards = 0) {}
/// always use one shard. The destructor function is used to destroy any
/// allocated storage instances.
ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
size_t numShards = 0)
: destructorFn(destructorFn) {}
~ParametricStorageUniquer() { destroyShardInstances(shard); }
/// Get or create an instance of a parametric type.
BaseStorage *
@ -228,6 +249,9 @@ private:
private:
/// The main uniquer shard that is used for allocating storage instances.
Shard shard;
/// Function to used to destruct any allocated storage instances.
function_ref<void(BaseStorage *)> destructorFn;
#endif
};
} // end anonymous namespace
@ -323,9 +347,10 @@ auto StorageUniquer::getParametricStorageTypeImpl(
/// Implementation for registering an instance of a derived type with
/// parametric storage.
void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) {
void StorageUniquer::registerParametricStorageTypeImpl(
TypeID id, function_ref<void(BaseStorage *)> destructorFn) {
impl->parametricUniquers.try_emplace(
id, std::make_unique<ParametricStorageUniquer>());
id, std::make_unique<ParametricStorageUniquer>(destructorFn));
}
/// Implementation for getting an instance of a derived type with default

View File

@ -100,6 +100,13 @@ void CompoundAAttr::print(DialectAsmPrinter &printer) const {
// TestDialect
//===----------------------------------------------------------------------===//
void TestDialect::registerAttributes() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
>();
}
Attribute TestDialect::parseAttribute(DialectAsmParser &parser,
Type type) const {
StringRef attrTag;

View File

@ -166,20 +166,14 @@ struct TestInlinerInterface : public DialectInlinerInterface {
//===----------------------------------------------------------------------===//
void TestDialect::initialize() {
registerAttributes();
registerTypes();
addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
>();
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
TestInlinerInterface>();
addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
>();
allowUnknownOperations();
}

View File

@ -32,6 +32,9 @@ def Test_Dialect : Dialect {
let dependentDialects = ["::mlir::DLTIDialect"];
let extraClassDeclaration = [{
void registerAttributes();
void registerTypes();
Attribute parseAttribute(DialectAsmParser &parser,
Type type) const override;
void printAttribute(Attribute attr,

View File

@ -164,6 +164,13 @@ unsigned TestTypeWithLayout::extractKind(DataLayoutEntryListRef params,
// TestDialect
//===----------------------------------------------------------------------===//
void TestDialect::registerTypes() {
addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
>();
}
static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
llvm::SetVector<Type> &stack) {
StringRef typeTag;

View File

@ -3,6 +3,7 @@ add_mlir_unittest(MLIRSupportTests
DebugCounterTest.cpp
IndentedOstreamTest.cpp
MathExtrasTest.cpp
StorageUniquerTest.cpp
)
target_link_libraries(MLIRSupportTests

View File

@ -0,0 +1,60 @@
//===- StorageUniquerTest.cpp - StorageUniquer Tests ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/StorageUniquer.h"
#include "gmock/gmock.h"
using namespace mlir;
namespace {
/// Simple storage class used for testing.
template <typename ConcreteT, typename... Args>
struct SimpleStorage : public StorageUniquer::BaseStorage {
using Base = SimpleStorage<ConcreteT, Args...>;
using KeyTy = std::tuple<Args...>;
SimpleStorage(KeyTy key) : key(key) {}
/// Get an instance of this storage instance.
template <typename... ParamsT>
static ConcreteT *get(StorageUniquer &uniquer, ParamsT &&...params) {
return uniquer.get<ConcreteT>(
/*initFn=*/{}, std::make_tuple(std::forward<ParamsT>(params)...));
}
/// Construct an instance with the given storage allocator.
static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
return new (alloc.allocate<ConcreteT>())
ConcreteT(std::forward<KeyTy>(key));
}
bool operator==(const KeyTy &key) const { return this->key == key; }
KeyTy key;
};
} // namespace
TEST(StorageUniquerTest, NonTrivialDestructor) {
struct NonTrivialStorage : public SimpleStorage<NonTrivialStorage, bool *> {
using Base::Base;
~NonTrivialStorage() {
bool *wasDestructed = std::get<0>(key);
*wasDestructed = true;
}
};
// Verify that the storage instance destructor was properly called.
bool wasDestructed = false;
{
StorageUniquer uniquer;
uniquer.registerParametricStorageType<NonTrivialStorage>();
NonTrivialStorage::get(uniquer, &wasDestructed);
}
EXPECT_TRUE(wasDestructed);
}