forked from OSchip/llvm-project
[mlir][StorageUniquer] Properly call the destructor on non-trivially destructible storage instances
This allows for storage instances to store data that isn't uniqued in the context, or contain otherwise non-trivial logic, in the rare situations that they occur. Storage instances with trivial destructors will still have their destructor skipped. A consequence of this is that the storage instance definition must be visible from the place that registers the type. Differential Revision: https://reviews.llvm.org/D98311
This commit is contained in:
parent
dc9c09632f
commit
31bb8efd69
|
@ -32,6 +32,12 @@ public:
|
|||
mlir::Type type) const override;
|
||||
void printAttribute(mlir::Attribute attr,
|
||||
mlir::DialectAsmPrinter &p) const override;
|
||||
|
||||
private:
|
||||
// Register the Attributes of this dialect.
|
||||
void registerAttributes();
|
||||
// Register the Types of this dialect.
|
||||
void registerTypes();
|
||||
};
|
||||
|
||||
} // namespace fir
|
||||
|
|
|
@ -243,3 +243,12 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
|
|||
os << "<(unknown attribute)>";
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FIROpsDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void FIROpsDialect::registerAttributes() {
|
||||
addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
|
||||
PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
|
||||
}
|
||||
|
|
|
@ -19,13 +19,8 @@ using namespace fir;
|
|||
|
||||
fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
|
||||
: mlir::Dialect("fir", ctx, mlir::TypeID::get<FIROpsDialect>()) {
|
||||
addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
|
||||
FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
|
||||
PointerType, RealType, RecordType, ReferenceType, SequenceType,
|
||||
ShapeType, ShapeShiftType, ShiftType, SliceType, TypeDescType,
|
||||
fir::VectorType>();
|
||||
addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
|
||||
PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
|
||||
registerTypes();
|
||||
registerAttributes();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "flang/Optimizer/Dialect/FIROps.cpp.inc"
|
||||
|
|
|
@ -866,3 +866,15 @@ mlir::LogicalResult fir::VectorType::verify(
|
|||
bool fir::VectorType::isValidElementType(mlir::Type t) {
|
||||
return isa_real(t) || isa_integer(t);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FIROpsDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void FIROpsDialect::registerTypes() {
|
||||
addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
|
||||
FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
|
||||
PointerType, RealType, RecordType, ReferenceType, SequenceType,
|
||||
ShapeType, ShapeShiftType, ShiftType, SliceType, TypeDescType,
|
||||
fir::VectorType>();
|
||||
}
|
||||
|
|
|
@ -319,7 +319,9 @@ public:
|
|||
|
||||
Once the dialect types have been defined, they must then be registered with a
|
||||
`Dialect`. This is done via a similar mechanism to
|
||||
[operations](LangRef.md#operations), with the `addTypes` method.
|
||||
[operations](LangRef.md#operations), with the `addTypes` method. The one
|
||||
distinct difference with operations, is that when a type is registered the
|
||||
definition of its storage class must be visible.
|
||||
|
||||
```c++
|
||||
struct MyDialect : public Dialect {
|
||||
|
|
|
@ -187,6 +187,9 @@ ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
|
|||
}
|
||||
```
|
||||
|
||||
(An important note here is that when registering a type, the definition of the
|
||||
storage class must be visible.)
|
||||
|
||||
With this we can now use our `StructType` when generating MLIR from Toy. See
|
||||
examples/toy/Ch7/mlir/MLIRGen.cpp for more details.
|
||||
|
||||
|
|
|
@ -76,33 +76,6 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
|
|||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToyDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Dialect creation, the instance will be owned by the context. This is the
|
||||
/// point of registration of custom types and operations for the dialect.
|
||||
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
|
||||
: mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "toy/Ops.cpp.inc"
|
||||
>();
|
||||
addInterfaces<ToyInlinerInterface>();
|
||||
addTypes<StructType>();
|
||||
}
|
||||
|
||||
mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type,
|
||||
mlir::Location loc) {
|
||||
if (type.isa<StructType>())
|
||||
return builder.create<StructConstantOp>(loc, type,
|
||||
value.cast<mlir::ArrayAttr>());
|
||||
return builder.create<ConstantOp>(loc, type,
|
||||
value.cast<mlir::DenseElementsAttr>());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Toy Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -566,3 +539,30 @@ void ToyDialect::printType(mlir::Type type,
|
|||
|
||||
#define GET_OP_CLASSES
|
||||
#include "toy/Ops.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToyDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Dialect creation, the instance will be owned by the context. This is the
|
||||
/// point of registration of custom types and operations for the dialect.
|
||||
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
|
||||
: mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "toy/Ops.cpp.inc"
|
||||
>();
|
||||
addInterfaces<ToyInlinerInterface>();
|
||||
addTypes<StructType>();
|
||||
}
|
||||
|
||||
mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type,
|
||||
mlir::Location loc) {
|
||||
if (type.isa<StructType>())
|
||||
return builder.create<StructConstantOp>(loc, type,
|
||||
value.cast<mlir::ArrayAttr>());
|
||||
return builder.create<ConstantOp>(loc, type,
|
||||
value.cast<mlir::DenseElementsAttr>());
|
||||
}
|
||||
|
|
|
@ -64,6 +64,9 @@ def PDL_Dialect : Dialect {
|
|||
|
||||
let name = "pdl";
|
||||
let cppNamespace = "::mlir::pdl";
|
||||
let extraClassDeclaration = [{
|
||||
void registerTypes();
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_PDL_IR_PDLDIALECT
|
||||
|
|
|
@ -52,6 +52,9 @@ def SPIRV_Dialect : Dialect {
|
|||
let hasRegionResultAttrVerify = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
void registerAttributes();
|
||||
void registerTypes();
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// Attribute
|
||||
//===------------------------------------------------------------------===//
|
||||
|
|
|
@ -22,6 +22,17 @@ def Builtin_Dialect : Dialect {
|
|||
|
||||
let name = "";
|
||||
let cppNamespace = "::mlir";
|
||||
let extraClassDeclaration = [{
|
||||
private:
|
||||
// Register the builtin Attributes.
|
||||
void registerAttributes();
|
||||
// Register the builtin Location Attributes.
|
||||
void registerLocationAttributes();
|
||||
// Register the builtin Types.
|
||||
void registerTypes();
|
||||
|
||||
public:
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // BUILTIN_BASE
|
||||
|
|
|
@ -135,7 +135,13 @@ public:
|
|||
/// instances of this class type. `id` is the type identifier that will be
|
||||
/// used to identify this type when creating instances of it via 'get'.
|
||||
template <typename Storage> void registerParametricStorageType(TypeID id) {
|
||||
registerParametricStorageTypeImpl(id);
|
||||
// If the storage is trivially destructible, we don't need a destructor
|
||||
// function.
|
||||
if (std::is_trivially_destructible<Storage>::value)
|
||||
return registerParametricStorageTypeImpl(id, nullptr);
|
||||
registerParametricStorageTypeImpl(id, [](BaseStorage *storage) {
|
||||
static_cast<Storage *>(storage)->~Storage();
|
||||
});
|
||||
}
|
||||
/// Utility override when the storage type represents the type id.
|
||||
template <typename Storage> void registerParametricStorageType() {
|
||||
|
@ -244,8 +250,10 @@ private:
|
|||
function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
|
||||
|
||||
/// Implementation for registering an instance of a derived type with
|
||||
/// parametric storage.
|
||||
void registerParametricStorageTypeImpl(TypeID id);
|
||||
/// parametric storage. This method takes an optional destructor function that
|
||||
/// destructs storage instances when necessary.
|
||||
void registerParametricStorageTypeImpl(
|
||||
TypeID id, function_ref<void(BaseStorage *)> destructorFn);
|
||||
|
||||
/// Implementation for getting an instance of a derived type with default
|
||||
/// storage.
|
||||
|
|
|
@ -20,6 +20,12 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
|
||||
|
||||
void arm_sve::ArmSVEDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
|
@ -31,12 +37,6 @@ void arm_sve::ArmSVEDialect::initialize() {
|
|||
>();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ScalableVectorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "TypeDetail.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
|
|
@ -25,10 +25,7 @@ void PDLDialect::initialize() {
|
|||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
|
||||
>();
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
|
||||
>();
|
||||
registerTypes();
|
||||
}
|
||||
|
||||
/// Returns true if the given operation is used by a "binding" pdl operation
|
||||
|
|
|
@ -26,6 +26,13 @@ using namespace mlir::pdl;
|
|||
// PDLDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PDLDialect::registerTypes() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
static Type parsePDLType(DialectAsmParser &parser) {
|
||||
StringRef typeTag;
|
||||
if (parser.parseKeyword(&typeTag))
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
||||
|
@ -350,3 +351,11 @@ spirv::TargetEnvAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
|||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void spirv::SPIRVDialect::registerAttributes() {
|
||||
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
|
||||
}
|
||||
|
|
|
@ -115,10 +115,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void SPIRVDialect::initialize() {
|
||||
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
|
||||
PointerType, RuntimeArrayType, SampledImageType, StructType>();
|
||||
|
||||
addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
|
||||
registerAttributes();
|
||||
registerTypes();
|
||||
|
||||
// Add SPIR-V ops.
|
||||
addOperations<
|
||||
|
|
|
@ -1154,3 +1154,12 @@ void MatrixType::getCapabilities(
|
|||
// Add any capabilities associated with the underlying vectors (i.e., columns)
|
||||
getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void SPIRVDialect::registerTypes() {
|
||||
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
|
||||
PointerType, RuntimeArrayType, SampledImageType, StructType>();
|
||||
}
|
||||
|
|
|
@ -100,36 +100,6 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
|
|||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void VectorDialect::initialize() {
|
||||
addAttributes<CombiningKindAttr>();
|
||||
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
/// Materialize a single constant operation from a given attribute value with
|
||||
/// the desired resultant type.
|
||||
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
return builder.create<ConstantOp>(loc, type, value);
|
||||
}
|
||||
|
||||
IntegerType vector::getVectorSubscriptType(Builder &builder) {
|
||||
return builder.getIntegerType(64);
|
||||
}
|
||||
|
||||
ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
|
||||
ArrayRef<int64_t> values) {
|
||||
return builder.getI64ArrayAttr(values);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CombiningKindAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -230,6 +200,36 @@ void VectorDialect::printAttribute(Attribute attr,
|
|||
llvm_unreachable("Unknown attribute type");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void VectorDialect::initialize() {
|
||||
addAttributes<CombiningKindAttr>();
|
||||
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
/// Materialize a single constant operation from a given attribute value with
|
||||
/// the desired resultant type.
|
||||
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
return builder.create<ConstantOp>(loc, type, value);
|
||||
}
|
||||
|
||||
IntegerType vector::getVectorSubscriptType(Builder &builder) {
|
||||
return builder.getIntegerType(64);
|
||||
}
|
||||
|
||||
ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
|
||||
ArrayRef<int64_t> values) {
|
||||
return builder.getI64ArrayAttr(values);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReductionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "AttributeDetail.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
|
@ -28,6 +29,18 @@ using namespace mlir::detail;
|
|||
#define GET_ATTRDEF_CLASSES
|
||||
#include "mlir/IR/BuiltinAttributes.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BuiltinDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void BuiltinDialect::registerAttributes() {
|
||||
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
|
||||
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
|
||||
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
|
||||
OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
|
||||
UnitAttr>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DictionaryAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -60,17 +60,9 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
|
|||
} // end anonymous namespace.
|
||||
|
||||
void BuiltinDialect::initialize() {
|
||||
addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
|
||||
Float80Type, Float128Type, FunctionType, IndexType, IntegerType,
|
||||
MemRefType, UnrankedMemRefType, NoneType, OpaqueType,
|
||||
RankedTensorType, TupleType, UnrankedTensorType, VectorType>();
|
||||
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
|
||||
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
|
||||
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
|
||||
OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
|
||||
UnitAttr>();
|
||||
addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
|
||||
UnknownLoc>();
|
||||
registerTypes();
|
||||
registerAttributes();
|
||||
registerLocationAttributes();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/IR/BuiltinOps.cpp.inc"
|
||||
|
|
|
@ -30,6 +30,17 @@ using namespace mlir::detail;
|
|||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/IR/BuiltinTypes.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BuiltinDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void BuiltinDialect::registerTypes() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "mlir/IR/BuiltinTypes.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// ComplexType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -514,7 +525,7 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
|
|||
if (!BaseMemRefType::isValidElementType(elementType))
|
||||
return emitError() << "invalid memref element type";
|
||||
|
||||
// Negative sizes are not allowed except for `-1` that means dynamic size.
|
||||
// Negative sizes are not allowed except for `-1` that means dynamic size.
|
||||
for (int64_t s : shape)
|
||||
if (s < -1)
|
||||
return emitError() << "invalid memref size";
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
|
@ -20,6 +21,17 @@ using namespace mlir::detail;
|
|||
#define GET_ATTRDEF_CLASSES
|
||||
#include "mlir/IR/BuiltinLocationAttributes.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BuiltinDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void BuiltinDialect::registerLocationAttributes() {
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
#include "mlir/IR/BuiltinLocationAttributes.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LocationAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -100,12 +100,23 @@ private:
|
|||
return storage;
|
||||
}
|
||||
|
||||
/// Destroy all of the storage instances within the given shard.
|
||||
void destroyShardInstances(Shard &shard) {
|
||||
if (!destructorFn)
|
||||
return;
|
||||
for (HashedStorage &instance : shard.instances)
|
||||
destructorFn(instance.storage);
|
||||
}
|
||||
|
||||
public:
|
||||
#if LLVM_ENABLE_THREADS != 0
|
||||
/// Initialize the storage uniquer with a given number of storage shards to
|
||||
/// use. The provided shard number is required to be a valid power of 2.
|
||||
ParametricStorageUniquer(size_t numShards = 8)
|
||||
: shards(new std::atomic<Shard *>[numShards]), numShards(numShards) {
|
||||
/// use. The provided shard number is required to be a valid power of 2. The
|
||||
/// destructor function is used to destroy any allocated storage instances.
|
||||
ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
|
||||
size_t numShards = 8)
|
||||
: shards(new std::atomic<Shard *>[numShards]), numShards(numShards),
|
||||
destructorFn(destructorFn) {
|
||||
assert(llvm::isPowerOf2_64(numShards) &&
|
||||
"the number of shards is required to be a power of 2");
|
||||
for (size_t i = 0; i < numShards; i++)
|
||||
|
@ -113,9 +124,12 @@ public:
|
|||
}
|
||||
~ParametricStorageUniquer() {
|
||||
// Free all of the allocated shards.
|
||||
for (size_t i = 0; i != numShards; ++i)
|
||||
if (Shard *shard = shards[i].load())
|
||||
for (size_t i = 0; i != numShards; ++i) {
|
||||
if (Shard *shard = shards[i].load()) {
|
||||
destroyShardInstances(*shard);
|
||||
delete shard;
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Get or create an instance of a parametric type.
|
||||
BaseStorage *
|
||||
|
@ -204,10 +218,17 @@ private:
|
|||
/// The number of available shards.
|
||||
size_t numShards;
|
||||
|
||||
/// Function to used to destruct any allocated storage instances.
|
||||
function_ref<void(BaseStorage *)> destructorFn;
|
||||
|
||||
#else
|
||||
/// If multi-threading is disabled, ignore the shard parameter as we will
|
||||
/// always use one shard.
|
||||
ParametricStorageUniquer(size_t numShards = 0) {}
|
||||
/// always use one shard. The destructor function is used to destroy any
|
||||
/// allocated storage instances.
|
||||
ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
|
||||
size_t numShards = 0)
|
||||
: destructorFn(destructorFn) {}
|
||||
~ParametricStorageUniquer() { destroyShardInstances(shard); }
|
||||
|
||||
/// Get or create an instance of a parametric type.
|
||||
BaseStorage *
|
||||
|
@ -228,6 +249,9 @@ private:
|
|||
private:
|
||||
/// The main uniquer shard that is used for allocating storage instances.
|
||||
Shard shard;
|
||||
|
||||
/// Function to used to destruct any allocated storage instances.
|
||||
function_ref<void(BaseStorage *)> destructorFn;
|
||||
#endif
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
@ -323,9 +347,10 @@ auto StorageUniquer::getParametricStorageTypeImpl(
|
|||
|
||||
/// Implementation for registering an instance of a derived type with
|
||||
/// parametric storage.
|
||||
void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) {
|
||||
void StorageUniquer::registerParametricStorageTypeImpl(
|
||||
TypeID id, function_ref<void(BaseStorage *)> destructorFn) {
|
||||
impl->parametricUniquers.try_emplace(
|
||||
id, std::make_unique<ParametricStorageUniquer>());
|
||||
id, std::make_unique<ParametricStorageUniquer>(destructorFn));
|
||||
}
|
||||
|
||||
/// Implementation for getting an instance of a derived type with default
|
||||
|
|
|
@ -100,6 +100,13 @@ void CompoundAAttr::print(DialectAsmPrinter &printer) const {
|
|||
// TestDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void TestDialect::registerAttributes() {
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
#include "TestAttrDefs.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
Attribute TestDialect::parseAttribute(DialectAsmParser &parser,
|
||||
Type type) const {
|
||||
StringRef attrTag;
|
||||
|
|
|
@ -166,20 +166,14 @@ struct TestInlinerInterface : public DialectInlinerInterface {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void TestDialect::initialize() {
|
||||
registerAttributes();
|
||||
registerTypes();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "TestOps.cpp.inc"
|
||||
>();
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
#include "TestAttrDefs.cpp.inc"
|
||||
>();
|
||||
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
|
||||
TestInlinerInterface>();
|
||||
addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "TestTypeDefs.cpp.inc"
|
||||
>();
|
||||
allowUnknownOperations();
|
||||
}
|
||||
|
||||
|
|
|
@ -32,6 +32,9 @@ def Test_Dialect : Dialect {
|
|||
let dependentDialects = ["::mlir::DLTIDialect"];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
void registerAttributes();
|
||||
void registerTypes();
|
||||
|
||||
Attribute parseAttribute(DialectAsmParser &parser,
|
||||
Type type) const override;
|
||||
void printAttribute(Attribute attr,
|
||||
|
|
|
@ -164,6 +164,13 @@ unsigned TestTypeWithLayout::extractKind(DataLayoutEntryListRef params,
|
|||
// TestDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void TestDialect::registerTypes() {
|
||||
addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "TestTypeDefs.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
|
||||
llvm::SetVector<Type> &stack) {
|
||||
StringRef typeTag;
|
||||
|
|
|
@ -3,6 +3,7 @@ add_mlir_unittest(MLIRSupportTests
|
|||
DebugCounterTest.cpp
|
||||
IndentedOstreamTest.cpp
|
||||
MathExtrasTest.cpp
|
||||
StorageUniquerTest.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(MLIRSupportTests
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
//===- StorageUniquerTest.cpp - StorageUniquer Tests ----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/StorageUniquer.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// Simple storage class used for testing.
|
||||
template <typename ConcreteT, typename... Args>
|
||||
struct SimpleStorage : public StorageUniquer::BaseStorage {
|
||||
using Base = SimpleStorage<ConcreteT, Args...>;
|
||||
using KeyTy = std::tuple<Args...>;
|
||||
|
||||
SimpleStorage(KeyTy key) : key(key) {}
|
||||
|
||||
/// Get an instance of this storage instance.
|
||||
template <typename... ParamsT>
|
||||
static ConcreteT *get(StorageUniquer &uniquer, ParamsT &&...params) {
|
||||
return uniquer.get<ConcreteT>(
|
||||
/*initFn=*/{}, std::make_tuple(std::forward<ParamsT>(params)...));
|
||||
}
|
||||
|
||||
/// Construct an instance with the given storage allocator.
|
||||
static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
|
||||
KeyTy key) {
|
||||
return new (alloc.allocate<ConcreteT>())
|
||||
ConcreteT(std::forward<KeyTy>(key));
|
||||
}
|
||||
bool operator==(const KeyTy &key) const { return this->key == key; }
|
||||
|
||||
KeyTy key;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST(StorageUniquerTest, NonTrivialDestructor) {
|
||||
struct NonTrivialStorage : public SimpleStorage<NonTrivialStorage, bool *> {
|
||||
using Base::Base;
|
||||
~NonTrivialStorage() {
|
||||
bool *wasDestructed = std::get<0>(key);
|
||||
*wasDestructed = true;
|
||||
}
|
||||
};
|
||||
|
||||
// Verify that the storage instance destructor was properly called.
|
||||
bool wasDestructed = false;
|
||||
{
|
||||
StorageUniquer uniquer;
|
||||
uniquer.registerParametricStorageType<NonTrivialStorage>();
|
||||
NonTrivialStorage::get(uniquer, &wasDestructed);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(wasDestructed);
|
||||
}
|
Loading…
Reference in New Issue