Revert "Separate the Registration from Loading dialects in the Context"

This reverts commit e1de2b7550.
Broke a build bot.
This commit is contained in:
Mehdi Amini 2020-08-18 22:15:59 +00:00
parent 4cbceb74bb
commit d84fe55e0d
93 changed files with 232 additions and 760 deletions

View File

@ -24,16 +24,9 @@
int main(int argc, char **argv) {
mlir::registerAllDialects();
mlir::registerAllPasses();
mlir::registerDialect<mlir::standalone::StandaloneDialect>();
// TODO: Register standalone passes here.
mlir::DialectRegistry registry;
mlir::registerDialect<mlir::standalone::StandaloneDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
// Add the following to include *all* MLIR Core dialects, or selectively
// include what you need like above. You only need to register dialects that
// will be *parsed* by the tool, not the one generated
// registerAllDialects(registry);
return failed(
mlir::MlirOptMain(argc, argv, "Standalone optimizer driver\n", registry));
return failed(mlir::MlirOptMain(argc, argv, "Standalone optimizer driver\n"));
}

View File

@ -68,9 +68,10 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
}
int dumpMLIR() {
mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
// Register our Dialect with MLIR.
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
// Handle '.toy' input to the compiler.
if (inputType != InputType::MLIR &&

View File

@ -102,10 +102,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
// Register our Dialect with MLIR.
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);

View File

@ -103,10 +103,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
// Register our Dialect with MLIR.
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);

View File

@ -256,9 +256,6 @@ struct TransposeOpLowering : public ConversionPattern {
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, StandardOpsDialect>();
}
void runOnFunction() final;
};
} // end anonymous namespace.

View File

@ -106,10 +106,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
// Register our Dialect with MLIR.
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);

View File

@ -255,9 +255,6 @@ struct TransposeOpLowering : public ConversionPattern {
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, StandardOpsDialect>();
}
void runOnFunction() final;
};
} // end anonymous namespace.

View File

@ -159,9 +159,6 @@ private:
namespace {
struct ToyToLLVMLoweringPass
: public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect, scf::SCFDialect>();
}
void runOnOperation() final;
};
} // end anonymous namespace

View File

@ -255,10 +255,10 @@ int main(int argc, char **argv) {
// If we aren't dumping the AST, then we are compiling with/to MLIR.
mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
// Register our Dialect with MLIR.
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (int error = loadAndProcessMLIR(context, module))
return error;

View File

@ -256,9 +256,6 @@ struct TransposeOpLowering : public ConversionPattern {
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, StandardOpsDialect>();
}
void runOnFunction() final;
};
} // end anonymous namespace.

View File

@ -159,9 +159,6 @@ private:
namespace {
struct ToyToLLVMLoweringPass
: public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect, scf::SCFDialect>();
}
void runOnOperation() final;
};
} // end anonymous namespace

View File

@ -256,10 +256,10 @@ int main(int argc, char **argv) {
// If we aren't dumping the AST, then we are compiling with/to MLIR.
mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
// Register our Dialect with MLIR.
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (int error = loadAndProcessMLIR(context, module))
return error;

View File

@ -88,12 +88,6 @@ MlirContext mlirContextCreate();
/** Takes an MLIR context owned by the caller and destroys it. */
void mlirContextDestroy(MlirContext context);
/** Load all the globally registered dialects in the provided context.
* TODO: remove the concept of globally registered dialect by exposing the
* DialectRegistry.
*/
void mlirContextLoadAllDialects(MlirContext context);
/*============================================================================*/
/* Location API. */
/*============================================================================*/

View File

@ -66,11 +66,6 @@ def ConvertAffineToStandard : Pass<"lower-affine"> {
`affine.apply`.
}];
let constructor = "mlir::createLowerAffinePass()";
let dependentDialects = [
"scf::SCFDialect",
"StandardOpsDialect",
"vector::VectorDialect"
];
}
//===----------------------------------------------------------------------===//
@ -81,7 +76,6 @@ def ConvertAVX512ToLLVM : Pass<"convert-avx512-to-llvm", "ModuleOp"> {
let summary = "Convert the operations from the avx512 dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertAVX512ToLLVMPass()";
let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"];
}
//===----------------------------------------------------------------------===//
@ -104,7 +98,6 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
let summary = "Generate NVVM operations for gpu operations";
let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()";
let dependentDialects = ["NVVM::NVVMDialect"];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
@ -119,7 +112,6 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
let summary = "Generate ROCDL operations for gpu operations";
let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
let dependentDialects = ["ROCDL::ROCDLDialect"];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
@ -134,7 +126,6 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
let summary = "Convert GPU dialect to SPIR-V dialect";
let constructor = "mlir::createConvertGPUToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
}
//===----------------------------------------------------------------------===//
@ -145,7 +136,6 @@ def ConvertGpuLaunchFuncToVulkanLaunchFunc
: Pass<"convert-gpu-launch-to-vulkan-launch", "ModuleOp"> {
let summary = "Convert gpu.launch_func to vulkanLaunch external call";
let constructor = "mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
}
def ConvertVulkanLaunchFuncToVulkanCalls
@ -153,7 +143,6 @@ def ConvertVulkanLaunchFuncToVulkanCalls
let summary = "Convert vulkanLaunch external call to Vulkan runtime external "
"calls";
let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()";
let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
@ -164,7 +153,6 @@ def ConvertLinalgToLLVM : Pass<"convert-linalg-to-llvm", "ModuleOp"> {
let summary = "Convert the operations from the linalg dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertLinalgToLLVMPass()";
let dependentDialects = ["scf::SCFDialect", "LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
@ -175,7 +163,6 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
let summary = "Convert the operations from the linalg dialect into the "
"Standard dialect";
let constructor = "mlir::createConvertLinalgToStandardPass()";
let dependentDialects = ["StandardOpsDialect"];
}
//===----------------------------------------------------------------------===//
@ -185,7 +172,6 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
def ConvertLinalgToSPIRV : Pass<"convert-linalg-to-spirv", "ModuleOp"> {
let summary = "Convert Linalg ops to SPIR-V ops";
let constructor = "mlir::createLinalgToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
}
//===----------------------------------------------------------------------===//
@ -196,7 +182,6 @@ def SCFToStandard : Pass<"convert-scf-to-std"> {
let summary = "Convert SCF dialect to Standard dialect, replacing structured"
" control flow with a CFG";
let constructor = "mlir::createLowerToCFGPass()";
let dependentDialects = ["StandardOpsDialect"];
}
//===----------------------------------------------------------------------===//
@ -206,7 +191,6 @@ def SCFToStandard : Pass<"convert-scf-to-std"> {
def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> {
let summary = "Convert top-level AffineFor Ops to GPU kernels";
let constructor = "mlir::createAffineForToGPUPass()";
let dependentDialects = ["gpu::GPUDialect"];
let options = [
Option<"numBlockDims", "gpu-block-dims", "unsigned", /*default=*/"1u",
"Number of GPU block dimensions for mapping">,
@ -218,7 +202,6 @@ def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> {
def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
let summary = "Convert mapped scf.parallel ops to gpu launch operations";
let constructor = "mlir::createParallelLoopToGpuPass()";
let dependentDialects = ["AffineDialect", "gpu::GPUDialect"];
}
//===----------------------------------------------------------------------===//
@ -229,7 +212,6 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
let summary = "Convert operations from the shape dialect into the standard "
"dialect";
let constructor = "mlir::createConvertShapeToStandardPass()";
let dependentDialects = ["StandardOpsDialect"];
}
//===----------------------------------------------------------------------===//
@ -239,7 +221,6 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
let summary = "Convert operations from the shape dialect to the SCF dialect";
let constructor = "mlir::createConvertShapeToSCFPass()";
let dependentDialects = ["scf::SCFDialect"];
}
//===----------------------------------------------------------------------===//
@ -249,7 +230,6 @@ def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> {
let summary = "Convert SPIR-V dialect to LLVM dialect";
let constructor = "mlir::createConvertSPIRVToLLVMPass()";
let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
@ -284,7 +264,6 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
LLVM IR types.
}];
let constructor = "mlir::createLowerToLLVMPass()";
let dependentDialects = ["LLVM::LLVMDialect"];
let options = [
Option<"useAlignedAlloc", "use-aligned-alloc", "bool", /*default=*/"false",
"Use aligned_alloc in place of malloc for heap allocations">,
@ -312,13 +291,11 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> {
let summary = "Legalize standard ops for SPIR-V lowering";
let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
}
def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> {
let summary = "Convert Standard Ops to SPIR-V dialect";
let constructor = "mlir::createConvertStandardToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
}
//===----------------------------------------------------------------------===//
@ -329,7 +306,6 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
let summary = "Lower the operations from the vector dialect into the SCF "
"dialect";
let constructor = "mlir::createConvertVectorToSCFPass()";
let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
let options = [
Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false",
"Perform full unrolling when converting vector transfers to SCF">,
@ -344,7 +320,6 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
let summary = "Lower the operations from the vector dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertVectorToLLVMPass()";
let dependentDialects = ["LLVM::LLVMDialect"];
let options = [
Option<"reassociateFPReductions", "reassociate-fp-reductions",
"bool", /*default=*/"false",
@ -360,7 +335,6 @@ def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> {
let summary = "Lower the operations from the vector dialect into the ROCDL "
"dialect";
let constructor = "mlir::createConvertVectorToROCDLPass()";
let dependentDialects = ["ROCDL::ROCDLDialect"];
}
#endif // MLIR_CONVERSION_PASSES

View File

@ -94,7 +94,6 @@ def AffineLoopUnrollAndJam : FunctionPass<"affine-loop-unroll-jam"> {
def AffineVectorize : FunctionPass<"affine-super-vectorize"> {
let summary = "Vectorize to a target independent n-D vector abstraction";
let constructor = "mlir::createSuperVectorizePass()";
let dependentDialects = ["vector::VectorDialect"];
let options = [
ListOption<"vectorSizes", "virtual-vector-size", "int64_t",
"Specify an n-D virtual vector size for vectorization",

View File

@ -15,7 +15,6 @@
#define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"

View File

@ -19,11 +19,6 @@ include "mlir/IR/OpBase.td"
def LLVM_Dialect : Dialect {
let name = "llvm";
let cppNamespace = "LLVM";
/// FIXME: at the moment this is a dependency of the translation to LLVM IR,
/// not really one of this dialect per-se.
let dependentDialects = ["omp::OpenMPDialect"];
let hasRegionArgAttrVerify = 1;
let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{

View File

@ -14,7 +14,6 @@
#ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -23,7 +23,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def NVVM_Dialect : Dialect {
let name = "nvvm";
let cppNamespace = "NVVM";
let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//

View File

@ -22,7 +22,6 @@
#ifndef MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_
#define MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -23,7 +23,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ROCDL_Dialect : Dialect {
let name = "rocdl";
let cppNamespace = "ROCDL";
let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//

View File

@ -30,20 +30,17 @@ def LinalgFusion : FunctionPass<"linalg-fusion"> {
def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
let summary = "Fuse operations on RankedTensorType in linalg dialect";
let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
let dependentDialects = ["AffineDialect"];
}
def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
let summary = "Lower the operations from the linalg dialect into affine "
"loops";
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
let dependentDialects = ["AffineDialect"];
}
def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
let summary = "Lower the operations from the linalg dialect into loops";
let constructor = "mlir::createConvertLinalgToLoopsPass()";
let dependentDialects = ["scf::SCFDialect", "AffineDialect"];
}
def LinalgOnTensorsToBuffers : Pass<"convert-linalg-on-tensors-to-buffers", "ModuleOp"> {
@ -57,7 +54,6 @@ def LinalgLowerToParallelLoops
let summary = "Lower the operations from the linalg dialect into parallel "
"loops";
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
}
def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
@ -74,9 +70,6 @@ def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
def LinalgTiling : FunctionPass<"linalg-tile"> {
let summary = "Tile operations in the linalg dialect";
let constructor = "mlir::createLinalgTilingPass()";
let dependentDialects = [
"AffineDialect", "scf::SCFDialect"
];
let options = [
ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
"Test generation of dynamic promoted buffers",
@ -93,7 +86,6 @@ def LinalgTilingToParallelLoops
"Test generation of dynamic promoted buffers",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
}
#endif // MLIR_DIALECT_LINALG_PASSES

View File

@ -36,7 +36,6 @@ def SCFParallelLoopTiling : FunctionPass<"parallel-loop-tiling"> {
"Factors to tile parallel loops by",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
let dependentDialects = ["AffineDialect"];
}
#endif // MLIR_DIALECT_SCF_PASSES

View File

@ -16,8 +16,6 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/TypeID.h"
#include <map>
namespace mlir {
class DialectAsmParser;
class DialectAsmPrinter;
@ -25,7 +23,7 @@ class DialectInterface;
class OpBuilder;
class Type;
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
/// Dialects are groups of MLIR operations and behavior associated with the
/// entire group. For example, hooks into other systems for constant folding,
@ -214,87 +212,30 @@ private:
/// A collection of registered dialect interfaces.
DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
/// Registers a specific dialect creation function with the global registry.
/// Used through the registerDialect template.
/// Registrations are deduplicated by dialect TypeID and only the first
/// registration will be used.
static void
registerDialectAllocator(TypeID typeID,
const DialectAllocatorFunction &function);
template <typename ConcreteDialect>
friend void registerDialect();
friend class MLIRContext;
};
/// The DialectRegistry maps a dialect namespace to a constructor for the
/// matching dialect.
/// This allows for decoupling the list of dialects "available" from the
/// dialects loaded in the Context. The parser in particular will lazily load
/// dialects in in the Context as operations are encountered.
class DialectRegistry {
using MapTy =
std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
public:
template <typename ConcreteDialect>
void insert() {
insert(TypeID::get<ConcreteDialect>(),
ConcreteDialect::getDialectNamespace(),
static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
// Just allocate the dialect, the context
// takes ownership of it.
return ctx->getOrLoadDialect<ConcreteDialect>();
})));
}
template <typename ConcreteDialect, typename OtherDialect,
typename... MoreDialects>
void insert() {
insert<ConcreteDialect>();
insert<OtherDialect, MoreDialects...>();
}
/// Add a new dialect constructor to the registry.
void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor);
/// Load a dialect for this namespace in the provided context.
Dialect *loadByName(StringRef name, MLIRContext *context);
// Register all dialects available in the current registry with the registry
// in the provided context.
void appendTo(DialectRegistry &destination) {
for (const auto &nameAndRegistrationIt : registry)
destination.insert(nameAndRegistrationIt.second.first,
nameAndRegistrationIt.first,
nameAndRegistrationIt.second.second);
}
// Load all dialects available in the registry in the provided context.
void loadAll(MLIRContext *context) {
for (const auto &nameAndRegistrationIt : registry)
nameAndRegistrationIt.second.second(context);
}
MapTy::const_iterator begin() const { return registry.begin(); }
MapTy::const_iterator end() const { return registry.end(); }
private:
MapTy registry;
};
/// Deprecated: this provides a global registry for convenience, while we're
/// transitionning the registration mechanism to a stateless approach.
DialectRegistry &getGlobalDialectRegistry();
/// Registers all dialects from the global registries with the
/// specified MLIRContext. This won't load the dialects in the context,
/// but only make them available for lazy loading by name.
/// Registers all dialects and hooks from the global registries with the
/// specified MLIRContext.
/// Note: This method is not thread-safe.
void registerAllDialects(MLIRContext *context);
/// Register and return the dialect with the given namespace in the provided
/// context. Returns nullptr is there is no constructor registered for this
/// dialect.
inline Dialect *registerDialect(StringRef name, MLIRContext *context) {
return getGlobalDialectRegistry().loadByName(name, context);
}
/// Utility to register a dialect. Client can register their dialect with the
/// global registry by calling registerDialect<MyDialect>();
/// Note: This method is not thread-safe.
template <typename ConcreteDialect> void registerDialect() {
getGlobalDialectRegistry().insert<ConcreteDialect>();
Dialect::registerDialectAllocator(
TypeID::get<ConcreteDialect>(),
[](MLIRContext *ctx) { ctx->getOrCreateDialect<ConcreteDialect>(); });
}
/// DialectRegistration provides a global initializer that registers a Dialect

View File

@ -428,7 +428,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
if (!attr.first.strref().contains('.'))
return funcOp.emitOpError("arguments may only have dialect attributes");
auto dialectNamePair = attr.first.strref().split('.');
if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
/*argIndex=*/i, attr)))
return failure();
@ -444,7 +444,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
if (!attr.first.strref().contains('.'))
return funcOp.emitOpError("results may only have dialect attributes");
auto dialectNamePair = attr.first.strref().split('.');
if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
/*resultIndex=*/i,
attr)))

View File

@ -19,12 +19,10 @@ namespace mlir {
class AbstractOperation;
class DiagnosticEngine;
class Dialect;
class DialectRegistry;
class InFlightDiagnostic;
class Location;
class MLIRContextImpl;
class StorageUniquer;
DialectRegistry &getGlobalDialectRegistry();
/// MLIRContext is the top-level object for a collection of MLIR modules. It
/// holds immortal uniqued objects like types, and the tables used to unique
@ -36,69 +34,34 @@ DialectRegistry &getGlobalDialectRegistry();
///
class MLIRContext {
public:
/// Create a new Context.
/// The loadAllDialects parameters allows to load all dialects from the global
/// registry on Context construction. It is deprecated and will be removed
/// soon.
explicit MLIRContext(bool loadAllDialects = true);
explicit MLIRContext();
~MLIRContext();
/// Return information about all IR dialects loaded in the context.
std::vector<Dialect *> getLoadedDialects();
/// Return the dialect registry associated with this context.
DialectRegistry &getDialectRegistry();
/// Return information about all available dialects in the registry in this
/// context.
std::vector<StringRef> getAvailableDialects();
/// Return information about all registered IR dialects.
std::vector<Dialect *> getRegisteredDialects();
/// Get a registered IR dialect with the given namespace. If an exact match is
/// not found, then return nullptr.
Dialect *getLoadedDialect(StringRef name);
Dialect *getRegisteredDialect(StringRef name);
/// Get a registered IR dialect for the given derived dialect type. The
/// derived type must provide a static 'getDialectNamespace' method.
template <typename T>
T *getLoadedDialect() {
return static_cast<T *>(getLoadedDialect(T::getDialectNamespace()));
template <typename T> T *getRegisteredDialect() {
return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace()));
}
/// Get (or create) a dialect for the given derived dialect type. The derived
/// type must provide a static 'getDialectNamespace' method.
template <typename T>
T *getOrLoadDialect() {
return static_cast<T *>(
getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() {
T *getOrCreateDialect() {
return static_cast<T *>(getOrCreateDialect(
T::getDialectNamespace(), TypeID::get<T>(), [this]() {
std::unique_ptr<T> dialect(new T(this));
dialect->dialectID = TypeID::get<T>();
return dialect;
}));
}
/// Load a dialect in the context.
template <typename Dialect>
void loadDialect() {
getOrLoadDialect<Dialect>();
}
/// Load a list dialects in the context.
template <typename Dialect, typename OtherDialect, typename... MoreDialects>
void loadDialect() {
getOrLoadDialect<Dialect>();
loadDialect<OtherDialect, MoreDialects...>();
}
/// Deprecated: load all globally registered dialects into this context.
/// This method will be removed soon, it can be used temporarily as we're
/// phasing out the global registry.
void loadAllGloballyRegisteredDialects();
/// Get (or create) a dialect for the given derived dialect name.
/// The dialect will be loaded from the registry if no dialect is found.
/// If no dialect is loaded for this name and none is available in the
/// registry, returns nullptr.
Dialect *getOrLoadDialect(StringRef name);
/// Return true if we allow to create operation for unregistered dialects.
bool allowsUnregisteredDialects();
@ -160,12 +123,10 @@ private:
const std::unique_ptr<MLIRContextImpl> impl;
/// Get a dialect for the provided namespace and TypeID: abort the program if
/// a dialect exist for this namespace with different TypeID. If a dialect has
/// not been loaded for this namespace/TypeID yet, use the provided ctor to
/// create one on the fly and load it. Returns a pointer to the dialect owned
/// by the context.
Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
function_ref<std::unique_ptr<Dialect>()> ctor);
/// a dialect exist for this namespace with different TypeID. Returns a
/// pointer to the dialect owned by the context.
Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
function_ref<std::unique_ptr<Dialect>()> ctor);
MLIRContext(const MLIRContext &) = delete;
void operator=(const MLIRContext &) = delete;

View File

@ -244,11 +244,6 @@ class Dialect {
// The description of the dialect.
string description = ?;
// A list of dialects this dialect will load on construction as dependencies.
// These are dialects that this dialect may involved in canonicalization
// pattern or interfaces.
list<string> dependentDialects = [];
// The C++ namespace that ops of this dialect should be placed into.
//
// By default, uses the name of the dialect as the only namespace. To avoid

View File

@ -35,35 +35,30 @@
namespace mlir {
// Add all the MLIR dialects to the provided registry.
inline void registerAllDialects(DialectRegistry &registry) {
// clang-format off
registry.insert<acc::OpenACCDialect,
AffineDialect,
avx512::AVX512Dialect,
gpu::GPUDialect,
LLVM::LLVMAVX512Dialect,
LLVM::LLVMDialect,
linalg::LinalgDialect,
scf::SCFDialect,
omp::OpenMPDialect,
quant::QuantizationDialect,
spirv::SPIRVDialect,
StandardOpsDialect,
vector::VectorDialect,
NVVM::NVVMDialect,
ROCDL::ROCDLDialect,
SDBMDialect,
shape::ShapeDialect>();
// clang-format on
}
// This function should be called before creating any MLIRContext if one expect
// all the possible dialects to be made available to the context automatically.
inline void registerAllDialects() {
static bool initOnce =
([]() { registerAllDialects(getGlobalDialectRegistry()); }(), true);
(void)initOnce;
static bool init_once = []() {
registerDialect<acc::OpenACCDialect>();
registerDialect<AffineDialect>();
registerDialect<avx512::AVX512Dialect>();
registerDialect<gpu::GPUDialect>();
registerDialect<LLVM::LLVMAVX512Dialect>();
registerDialect<LLVM::LLVMDialect>();
registerDialect<linalg::LinalgDialect>();
registerDialect<scf::SCFDialect>();
registerDialect<omp::OpenMPDialect>();
registerDialect<quant::QuantizationDialect>();
registerDialect<spirv::SPIRVDialect>();
registerDialect<StandardOpsDialect>();
registerDialect<vector::VectorDialect>();
registerDialect<NVVM::NVVMDialect>();
registerDialect<ROCDL::ROCDLDialect>();
registerDialect<SDBMDialect>();
registerDialect<shape::ShapeDialect>();
return true;
}();
(void)init_once;
}
} // namespace mlir

View File

@ -28,7 +28,7 @@ void registerAVX512ToLLVMIRTranslation();
// expects all the possible translations to be made available to the context
// automatically.
inline void registerAllTranslations() {
static bool initOnce = []() {
static bool init_once = []() {
registerFromLLVMIRTranslation();
registerFromSPIRVTranslation();
registerToLLVMIRTranslation();
@ -38,7 +38,7 @@ inline void registerAllTranslations() {
registerAVX512ToLLVMIRTranslation();
return true;
}();
(void)initOnce;
(void)init_once;
}
} // namespace mlir

View File

@ -9,7 +9,6 @@
#ifndef MLIR_PASS_PASS_H
#define MLIR_PASS_PASS_H
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/PassRegistry.h"
@ -58,13 +57,6 @@ public:
/// Returns the derived pass name.
virtual StringRef getName() const = 0;
/// Register dependent dialects for the current pass.
/// A pass is expected to register the dialects it will create entities for
/// (Operations, Types, Attributes), other than dialect that exists in the
/// input. For example, a pass that converts from Linalg to Affine would
/// register the Affine dialect but does not need to register Linalg.
virtual void getDependentDialects(DialectRegistry &registry) const {}
/// Returns the command line argument used when registering this pass. Return
/// an empty string if one does not exist.
virtual StringRef getArgument() const {

View File

@ -78,9 +78,6 @@ class PassBase<string passArg, string base> {
// A C++ constructor call to create an instance of this pass.
code constructor = [{}];
// A list of dialects this pass may produce entities in.
list<string> dependentDialects = [];
// A set of options provided by this pass.
list<Option> options = [];

View File

@ -9,7 +9,6 @@
#ifndef MLIR_PASS_PASSMANAGER_H
#define MLIR_PASS_PASSMANAGER_H
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Optional.h"
@ -59,14 +58,6 @@ public:
pass_iterator end();
iterator_range<pass_iterator> getPasses() { return {begin(), end()}; }
using const_pass_iterator = llvm::pointee_iterator<
std::vector<std::unique_ptr<Pass>>::const_iterator>;
const_pass_iterator begin() const;
const_pass_iterator end() const;
iterator_range<const_pass_iterator> getPasses() const {
return {begin(), end()};
}
/// Run the held passes over the given operation.
LogicalResult run(Operation *op, AnalysisManager am);
@ -109,11 +100,6 @@ public:
/// Merge the pass statistics of this class into 'other'.
void mergeStatisticsInto(OpPassManager &other);
/// Register dependent dialects for the current pass manager.
/// This is forwarding to every pass in this PassManager, see the
/// documentation for the same method on the Pass class.
void getDependentDialects(DialectRegistry &dialects) const;
private:
OpPassManager(OperationName name, bool verifyPasses);

View File

@ -21,14 +21,12 @@ class MemoryBuffer;
} // end namespace llvm
namespace mlir {
class DialectRegistry;
class PassPipelineCLParser;
/// Perform the core processing behind `mlir-opt`:
/// - outputStream is the stream where the resulting IR is printed.
/// - buffer is the in-memory file to parser and process.
/// - passPipeline is the specification of the pipeline that will be applied.
/// - registry should contain all the dialects that can be parsed in the source.
/// - splitInputFile will look for a "-----" marker in the input file, and load
/// each chunk in an individual ModuleOp processed separately.
/// - verifyDiagnostics enables a verification mode where comments starting with
@ -37,25 +35,13 @@ class PassPipelineCLParser;
/// - verifyPasses enables the IR verifier in-between each pass in the pipeline.
/// - allowUnregisteredDialects allows to parse and create operation without
/// registering the Dialect in the MLIRContext.
/// - preloadDialectsInContext will trigger the upfront loading of all
/// dialects from the global registry in the MLIRContext. This option is
/// deprecated and will be removed soon.
LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
std::unique_ptr<llvm::MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline,
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
bool preloadDialectsInContext = true);
bool splitInputFile, bool verifyDiagnostics,
bool verifyPasses, bool allowUnregisteredDialects);
/// Implementation for tools like `mlir-opt`.
/// - toolName is used for the header displayed by `--help`.
/// - registry should contain all the dialects that can be parsed in the source.
/// - preloadDialectsInContext will trigger the upfront loading of all
/// dialects from the global registry in the MLIRContext. This option is
/// deprecated and will be removed soon.
LogicalResult MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
DialectRegistry &registry,
bool preloadDialectsInContext = true);
LogicalResult MlirOptMain(int argc, char **argv, llvm::StringRef toolName);
} // end namespace mlir

View File

@ -14,7 +14,6 @@
#include "mlir/Support/LLVM.h"
#include <string>
#include <vector>
namespace llvm {
class Record;
@ -26,7 +25,7 @@ namespace tblgen {
// and provides helper methods for accessing them.
class Dialect {
public:
explicit Dialect(const llvm::Record *def);
explicit Dialect(const llvm::Record *def) : def(def) {}
// Returns the name of this dialect.
StringRef getName() const;
@ -44,10 +43,6 @@ public:
// Returns the description of the dialect. Returns empty string if none.
StringRef getDescription() const;
// Returns the list of dialect (class names) that this dialect depends on.
// These are dialects that will be loaded on construction of this dialect.
ArrayRef<StringRef> getDependentDialects() const;
// Returns the dialects extra class declaration code.
llvm::Optional<StringRef> getExtraClassDeclaration() const;
@ -75,7 +70,6 @@ public:
private:
const llvm::Record *def;
std::vector<StringRef> dependentDialects;
};
} // end namespace tblgen
} // end namespace mlir

View File

@ -94,9 +94,6 @@ public:
/// Return the C++ constructor call to create an instance of this pass.
StringRef getConstructor() const;
/// Return the dialects this pass needs to be registered.
ArrayRef<StringRef> getDependentDialects() const;
/// Return the options provided by this pass.
ArrayRef<PassOption> getOptions() const;
@ -107,7 +104,6 @@ public:
private:
const llvm::Record *def;
std::vector<StringRef> dependentDialects;
std::vector<PassOption> options;
std::vector<PassStatistic> statistics;
};

View File

@ -162,8 +162,6 @@ def BufferPlacement : FunctionPass<"buffer-placement"> {
}];
let constructor = "mlir::createBufferPlacementPass()";
// TODO: this pass likely shouldn't depend on Linalg?
let dependentDialects = ["linalg::LinalgDialect"];
}
def Canonicalizer : Pass<"canonicalize"> {

View File

@ -10,11 +10,9 @@
#include "mlir/CAPI/IR.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h"
#include "llvm/Support/raw_ostream.h"
@ -52,17 +50,12 @@ private:
/* ========================================================================== */
MlirContext mlirContextCreate() {
auto *context = new MLIRContext(/*loadAllDialects=*/false);
auto *context = new MLIRContext;
return wrap(context);
}
void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
void mlirContextLoadAllDialects(MlirContext context) {
registerAllDialects(unwrap(context));
getGlobalDialectRegistry().loadAll(unwrap(context));
}
/* ========================================================================== */
/* Location API. */
/* ========================================================================== */

View File

@ -16,7 +16,6 @@
#include "../PassDetail.h"
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

View File

@ -19,7 +19,6 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

View File

@ -12,43 +12,11 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
class AffineDialect;
class StandardOpsDialect;
// Forward declaration from Dialect.h
template <typename ConcreteDialect>
void registerDialect(DialectRegistry &registry);
namespace gpu {
class GPUDialect;
class GPUModuleOp;
} // end namespace gpu
namespace LLVM {
class LLVMDialect;
class LLVMAVX512Dialect;
} // end namespace LLVM
namespace NVVM {
class NVVMDialect;
} // end namespace NVVM
namespace ROCDL {
class ROCDLDialect;
} // end namespace ROCDL
namespace scf {
class SCFDialect;
} // end namespace scf
namespace spirv {
class SPIRVDialect;
} // end namespace spirv
namespace vector {
class VectorDialect;
} // end namespace vector
#define GEN_PASS_CLASSES
#include "mlir/Conversion/Passes.h.inc"

View File

@ -125,7 +125,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
const LowerToLLVMOptions &options)
: llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
options(options) {
assert(llvmDialect && "LLVM IR dialect is not registered");
if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)

View File

@ -14,7 +14,6 @@
#include "../PassDetail.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

View File

@ -12,16 +12,6 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
// Forward declaration from Dialect.h
template <typename ConcreteDialect>
void registerDialect(DialectRegistry &registry);
namespace linalg {
class LinalgDialect;
} // end namespace linalg
namespace vector {
class VectorDialect;
} // end namespace vector
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Affine/Passes.h.inc"

View File

@ -1244,7 +1244,6 @@ template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
result.getContext()->getOrLoadDialect<StandardOpsDialect>();
// Optional attributes may be added.
if (parser.parseOperandList(operandsInfo) ||

View File

@ -9,18 +9,9 @@
#ifndef DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_
#define DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
// Forward declaration from Dialect.h
template <typename ConcreteDialect>
void registerDialect(DialectRegistry &registry);
namespace scf {
class SCFDialect;
} // end namespace scf
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Linalg/Passes.h.inc"

View File

@ -12,11 +12,6 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
// Forward declaration from Dialect.h
template <typename ConcreteDialect>
void registerDialect(DialectRegistry &registry);
class AffineDialect;
#define GEN_PASS_CLASSES
#include "mlir/Dialect/SCF/Passes.h.inc"

View File

@ -517,7 +517,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
SDBMDialect *dialect;
} converter;
converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>();
converter.dialect = affine.getContext()->getRegisteredDialect<SDBMDialect>();
if (auto result = converter.visit(affine))
return result;

View File

@ -259,9 +259,7 @@ int mlir::JitRunnerMain(
}
}
MLIRContext context(/*loadAllDialects=*/false);
registerAllDialects(&context);
MLIRContext context;
auto m = parseMLIRInput(options.inputFilename, &context);
if (!m) {
llvm::errs() << "could not parse the input IR\n";

View File

@ -27,29 +27,21 @@ DialectAsmParser::~DialectAsmParser() {}
//===----------------------------------------------------------------------===//
/// Registry for all dialect allocation functions.
static llvm::ManagedStatic<DialectRegistry> dialectRegistry;
DialectRegistry &mlir::getGlobalDialectRegistry() { return *dialectRegistry; }
static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>>
dialectRegistry;
void Dialect::registerDialectAllocator(
TypeID typeID, const DialectAllocatorFunction &function) {
assert(function &&
"Attempting to register an empty dialect initialize function");
dialectRegistry->insert({typeID, function});
}
/// Registers all dialects and hooks from the global registries with the
/// specified MLIRContext.
void mlir::registerAllDialects(MLIRContext *context) {
dialectRegistry->appendTo(context->getDialectRegistry());
}
Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
auto it = registry.find(name.str());
if (it == registry.end())
return nullptr;
return it->second.second(context);
}
void DialectRegistry::insert(TypeID typeID, StringRef name,
DialectAllocatorFunction ctor) {
auto inserted =
registry.insert(std::make_pair(name, std::make_pair(typeID, ctor)));
if (!inserted.second && inserted.first->second.first != typeID) {
llvm::report_fatal_error(
"Trying to register different dialects for the same namespace: " +
name);
}
for (const auto &it : *dialectRegistry)
it.second(context);
}
//===----------------------------------------------------------------------===//
@ -127,7 +119,7 @@ DialectInterface::~DialectInterface() {}
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
MLIRContext *ctx, TypeID interfaceKind) {
for (auto *dialect : ctx->getLoadedDialects()) {
for (auto *dialect : ctx->getRegisteredDialects()) {
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
interfaces.insert(interface);
orderedInterfaces.push_back(interface);

View File

@ -31,13 +31,10 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/RWMutex.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
#define DEBUG_TYPE "mlircontext"
using namespace mlir;
using namespace mlir::detail;
@ -278,8 +275,7 @@ public:
/// This is a list of dialects that are created referring to this context.
/// The MLIRContext owns the objects.
DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
DialectRegistry dialectsRegistry;
std::vector<std::unique_ptr<Dialect>> dialects;
/// This is a mapping from operation name to AbstractOperation for registered
/// operations.
@ -350,7 +346,7 @@ public:
};
} // end namespace mlir
MLIRContext::MLIRContext(bool loadAllDialects) : impl(new MLIRContextImpl()) {
MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
// Initialize values based on the command line flags if they were provided.
if (clOptions.isConstructed()) {
disableMultithreading(clOptions->disableThreading);
@ -359,9 +355,8 @@ MLIRContext::MLIRContext(bool loadAllDialects) : impl(new MLIRContextImpl()) {
}
// Register dialects with this context.
getOrLoadDialect<BuiltinDialect>();
if (loadAllDialects)
loadAllGloballyRegisteredDialects();
getOrCreateDialect<BuiltinDialect>();
registerAllDialects(this);
// Initialize several common attributes and types to avoid the need to lock
// the context when accessing them.
@ -443,72 +438,54 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
// Dialect and Operation Registration
//===----------------------------------------------------------------------===//
DialectRegistry &MLIRContext::getDialectRegistry() {
return impl->dialectsRegistry;
}
/// Return information about all registered IR dialects.
std::vector<Dialect *> MLIRContext::getLoadedDialects() {
std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
std::vector<Dialect *> result;
result.reserve(impl->loadedDialects.size());
for (auto &dialect : impl->loadedDialects)
result.push_back(dialect.second.get());
llvm::array_pod_sort(result.begin(), result.end(),
[](Dialect *const *lhs, Dialect *const *rhs) -> int {
return (*lhs)->getNamespace() < (*rhs)->getNamespace();
});
return result;
}
std::vector<StringRef> MLIRContext::getAvailableDialects() {
std::vector<StringRef> result;
for (auto &dialect : impl->dialectsRegistry)
result.push_back(dialect.first);
result.reserve(impl->dialects.size());
for (auto &dialect : impl->dialects)
result.push_back(dialect.get());
return result;
}
/// Get a registered IR dialect with the given namespace. If none is found,
/// then return nullptr.
Dialect *MLIRContext::getLoadedDialect(StringRef name) {
Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
// Dialects are sorted by name, so we can use binary search for lookup.
auto it = impl->loadedDialects.find(name);
return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
}
Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
Dialect *dialect = getLoadedDialect(name);
if (dialect)
return dialect;
return impl->dialectsRegistry.loadByName(name, this);
auto it = llvm::lower_bound(
impl->dialects, name,
[](const auto &lhs, StringRef rhs) { return lhs->getNamespace() < rhs; });
return (it != impl->dialects.end() && (*it)->getNamespace() == name)
? (*it).get()
: nullptr;
}
/// Get a dialect for the provided namespace and TypeID: abort the program if a
/// dialect exist for this namespace with different TypeID. Returns a pointer to
/// the dialect owned by the context.
Dialect *
MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
function_ref<std::unique_ptr<Dialect>()> ctor) {
MLIRContext::getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
function_ref<std::unique_ptr<Dialect>()> ctor) {
auto &impl = getImpl();
// Get the correct insertion position sorted by namespace.
std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
if (!dialect) {
LLVM_DEBUG(llvm::dbgs()
<< "Load new dialect in Context" << dialectNamespace);
dialect = ctor();
assert(dialect && "dialect ctor failed");
return dialect.get();
}
auto insertPt =
llvm::lower_bound(impl.dialects, nullptr,
[&](const std::unique_ptr<Dialect> &lhs,
const std::unique_ptr<Dialect> &rhs) {
if (!lhs)
return dialectNamespace < rhs->getNamespace();
return lhs->getNamespace() < dialectNamespace;
});
// Abort if dialect with namespace has already been registered.
if (dialect->getTypeID() != dialectID)
if (insertPt != impl.dialects.end() &&
(*insertPt)->getNamespace() == dialectNamespace) {
if ((*insertPt)->getTypeID() == dialectID)
return insertPt->get();
llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
"' has already been registered");
return dialect.get();
}
void MLIRContext::loadAllGloballyRegisteredDialects() {
getGlobalDialectRegistry().loadAll(this);
}
auto it = impl.dialects.insert(insertPt, ctor());
return &**it;
}
bool MLIRContext::allowsUnregisteredDialects() {

View File

@ -214,7 +214,7 @@ Dialect *Operation::getDialect() {
// If this operation hasn't been registered or doesn't have abstract
// operation, try looking up the dialect name in the context.
return getContext()->getLoadedDialect(getName().getDialect());
return getContext()->getRegisteredDialect(getName().getDialect());
}
Region *Operation::getParentRegion() {

View File

@ -50,7 +50,7 @@ public:
Dialect *getDialectForAttribute(const NamedAttribute &attr) {
assert(attr.first.strref().contains('.') && "expected dialect attribute");
auto dialectNamePair = attr.first.strref().split('.');
return ctx->getLoadedDialect(dialectNamePair.first);
return ctx->getRegisteredDialect(dialectNamePair.first);
}
private:
@ -218,7 +218,7 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
auto it = dialectAllowsUnknownOps.find(dialectPrefix);
if (it == dialectAllowsUnknownOps.end()) {
// If the operation dialect is registered, query it directly.
if (auto *dialect = ctx->getLoadedDialect(dialectPrefix))
if (auto *dialect = ctx->getRegisteredDialect(dialectPrefix))
it = dialectAllowsUnknownOps
.try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
.first;

View File

@ -12,7 +12,6 @@
#include "Parser.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/StringExtras.h"
@ -247,11 +246,6 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return emitError("duplicate key in dictionary attribute");
consumeToken();
// Lazy load a dialect in the context if there is a possible namespace.
auto splitName = nameId->strref().split('.');
if (!splitName.second.empty())
getContext()->getOrLoadDialect(splitName.first);
// Try to parse the '=' for the attribute value.
if (!consumeIf(Token::equal)) {
// If there is no '=', we treat this as a unit attribute.
@ -823,9 +817,7 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
return (emitError("expected dialect namespace"), nullptr);
auto name = getToken().getStringValue();
// Lazy load a dialect in the context if there is a possible namespace.
Dialect *dialect = builder.getContext()->getOrLoadDialect(name);
auto *dialect = builder.getContext()->getRegisteredDialect(name);
// TODO: Allow for having an unknown dialect on an opaque
// attribute. Otherwise, it can't be roundtripped without having the dialect
// registered.

View File

@ -526,8 +526,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
return Attribute();
// If we found a registered dialect, then ask it to parse the attribute.
if (Dialect *dialect =
builder.getContext()->getOrLoadDialect(dialectName)) {
if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
return parseSymbol<Attribute>(
symbolData, state.context, state.symbols, [&](Parser &parser) {
CustomDialectAsmParser customParser(symbolData, parser);
@ -564,9 +563,7 @@ Type Parser::parseExtendedType() {
[&](StringRef dialectName, StringRef symbolData,
llvm::SMLoc loc) -> Type {
// If we found a registered dialect, then ask it to parse the type.
auto *dialect = state.context->getOrLoadDialect(dialectName);
if (dialect) {
if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
return parseSymbol<Type>(
symbolData, state.context, state.symbols, [&](Parser &parser) {
CustomDialectAsmParser customParser(symbolData, parser);

View File

@ -12,7 +12,6 @@
#include "Parser.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
@ -728,7 +727,7 @@ Operation *OperationParser::parseGenericOperation() {
// Get location information for the operation.
auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
std::string name = getToken().getStringValue();
auto name = getToken().getStringValue();
if (name.empty())
return (emitError("empty operation name is invalid"), nullptr);
if (name.find('\0') != StringRef::npos)
@ -738,15 +737,6 @@ Operation *OperationParser::parseGenericOperation() {
OperationState result(srcLocation, name);
// Lazy load dialects in the context as needed.
if (!result.name.getAbstractOperation()) {
StringRef dialectName = StringRef(name).split('.').first;
if (!getContext()->getLoadedDialect(dialectName) &&
getContext()->getOrLoadDialect(dialectName)) {
result.name = OperationName(name, getContext());
}
}
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
@ -1452,28 +1442,17 @@ private:
Operation *
OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
llvm::SMLoc opLoc = getToken().getLoc();
StringRef opName = getTokenSpelling();
auto opLoc = getToken().getLoc();
auto opName = getTokenSpelling();
auto *opDefinition = AbstractOperation::lookup(opName, getContext());
if (!opDefinition) {
if (opName.contains('.')) {
// This op has a dialect, we try to check if we can register it in the
// context on the fly.
StringRef dialectName = opName.split('.').first;
if (!getContext()->getLoadedDialect(dialectName) &&
getContext()->getOrLoadDialect(dialectName)) {
opDefinition = AbstractOperation::lookup(opName, getContext());
}
} else {
// If the operation name has no namespace prefix we treat it as a standard
// operation and prefix it with "std".
// TODO: Would it be better to just build a mapping of the registered
// operations in the standard dialect?
if (getContext()->getOrLoadDialect("std"))
opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(),
getContext());
}
if (!opDefinition && !opName.contains('.')) {
// If the operation name has no namespace prefix we treat it as a standard
// operation and prefix it with "std".
// TODO: Would it be better to just build a mapping of the registered
// operations in the standard dialect?
opDefinition =
AbstractOperation::lookup(Twine("std." + opName).str(), getContext());
}
if (!opDefinition) {

View File

@ -290,13 +290,6 @@ OpPassManager::pass_iterator OpPassManager::begin() {
}
OpPassManager::pass_iterator OpPassManager::end() { return impl->passes.end(); }
OpPassManager::const_pass_iterator OpPassManager::begin() const {
return impl->passes.begin();
}
OpPassManager::const_pass_iterator OpPassManager::end() const {
return impl->passes.end();
}
/// Run all of the passes in this manager over the current operation.
LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) {
// Run each of the held passes.
@ -353,16 +346,6 @@ void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
::printAsTextualPipeline(impl->passes, os);
}
static void registerDialectsForPipeline(const OpPassManager &pm,
DialectRegistry &dialects) {
for (const Pass &pass : pm.getPasses())
pass.getDependentDialects(dialects);
}
void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
registerDialectsForPipeline(*this, dialects);
}
//===----------------------------------------------------------------------===//
// OpToOpPassAdaptor
//===----------------------------------------------------------------------===//
@ -395,11 +378,6 @@ OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
mgrs.emplace_back(std::move(mgr));
}
void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
for (auto &pm : mgrs)
pm.getDependentDialects(dialects);
}
/// Merge the current pass adaptor into given 'rhs'.
void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
for (auto &pm : mgrs) {
@ -743,11 +721,6 @@ LogicalResult PassManager::run(ModuleOp module) {
// pipeline.
getImpl().coalesceAdjacentAdaptorPasses();
// Register all dialects for the current pipeline.
DialectRegistry dependentDialects;
getDependentDialects(dependentDialects);
dependentDialects.loadAll(module.getContext());
// Construct an analysis manager for the pipeline.
ModuleAnalysisManager am(module, instrumentor.get());

View File

@ -43,10 +43,6 @@ public:
/// Returns the pass managers held by this adaptor.
MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; }
/// Populate the set of dependent dialects for the passes in the current
/// adaptor.
void getDependentDialects(DialectRegistry &dialects) const override;
/// Return the async pass managers held by this parallel adaptor.
MutableArrayRef<SmallVector<OpPassManager, 1>> getParallelPassManagers() {
return asyncExecutors;

View File

@ -81,18 +81,13 @@ static LogicalResult processBuffer(raw_ostream &os,
std::unique_ptr<MemoryBuffer> ownedBuffer,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
bool preloadDialectsInContext,
const PassPipelineCLParser &passPipeline,
DialectRegistry &registry) {
const PassPipelineCLParser &passPipeline) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
// Parse the input file.
MLIRContext context(/*loadAllDialects=*/preloadDialectsInContext);
registry.appendTo(context.getDialectRegistry());
if (preloadDialectsInContext)
registry.loadAll(&context);
MLIRContext context;
context.allowUnregisteredDialects(allowUnregisteredDialects);
context.printOpOnDiagnostic(!verifyDiagnostics);
@ -120,10 +115,9 @@ static LogicalResult processBuffer(raw_ostream &os,
LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
std::unique_ptr<MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline,
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
bool preloadDialectsInContext) {
bool splitInputFile, bool verifyDiagnostics,
bool verifyPasses,
bool allowUnregisteredDialects) {
// The split-input-file mode is a very specific mode that slices the file
// up into small pieces and checks each independently.
if (splitInputFile)
@ -132,19 +126,15 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
[&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
preloadDialectsInContext, passPipeline,
registry);
passPipeline);
},
outputStream);
return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
preloadDialectsInContext, passPipeline, registry);
verifyPasses, allowUnregisteredDialects, passPipeline);
}
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
DialectRegistry &registry,
bool preloadDialectsInContext) {
LogicalResult mlir::MlirOptMain(int argc, char **argv, StringRef toolName) {
static cl::opt<std::string> inputFilename(
cl::Positional, cl::desc("<input file>"), cl::init("-"));
@ -190,19 +180,25 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
{
llvm::raw_string_ostream os(helpHeader);
MLIRContext context;
interleaveComma(registry, os, [&](auto &registryEntry) {
StringRef name = registryEntry.first;
os << name;
interleaveComma(context.getRegisteredDialects(), os, [&](Dialect *dialect) {
StringRef name = dialect->getNamespace();
// filter the builtin dialect.
if (name.empty())
os << "<builtin>";
else
os << name;
});
}
// Parse pass names in main to ensure static initialization completed.
cl::ParseCommandLineOptions(argc, argv, helpHeader);
if (showDialects) {
llvm::outs() << "Available Dialects:\n";
llvm::outs() << "Registered Dialects:\n";
MLIRContext context;
interleave(
registry, llvm::outs(),
[](auto &registryEntry) { llvm::outs() << registryEntry.first; }, "\n");
context.getRegisteredDialects(), llvm::outs(),
[](Dialect *dialect) { llvm::outs() << dialect->getNamespace(); },
"\n");
return success();
}
@ -220,9 +216,9 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
return failure();
}
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
splitInputFile, verifyDiagnostics, verifyPasses,
allowUnregisteredDialects, preloadDialectsInContext)))
allowUnregisteredDialects)))
return failure();
// Keep the output file if the invocation of MlirOptMain was successful.

View File

@ -15,10 +15,6 @@
using namespace mlir;
using namespace mlir::tblgen;
Dialect::Dialect(const llvm::Record *def) : def(def) {
for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
dependentDialects.push_back(dialect);
}
StringRef Dialect::getName() const { return def->getValueAsString("name"); }
@ -50,10 +46,6 @@ StringRef Dialect::getDescription() const {
return getAsStringOrEmpty(*def, "description");
}
ArrayRef<StringRef> Dialect::getDependentDialects() const {
return dependentDialects;
}
llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;

View File

@ -69,8 +69,6 @@ Pass::Pass(const llvm::Record *def) : def(def) {
options.push_back(PassOption(init));
for (auto *init : def->getValueAsListOfDefs("statistics"))
statistics.push_back(PassStatistic(init));
for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
dependentDialects.push_back(dialect);
}
StringRef Pass::getArgument() const {
@ -90,9 +88,6 @@ StringRef Pass::getDescription() const {
StringRef Pass::getConstructor() const {
return def->getValueAsString("constructor");
}
ArrayRef<StringRef> Pass::getDependentDialects() const {
return dependentDialects;
}
ArrayRef<PassOption> Pass::getOptions() const { return options; }

View File

@ -836,7 +836,6 @@ LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
OwningModuleRef
mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
MLIRContext *context) {
context->loadDialect<LLVMDialect>();
OwningModuleRef module(ModuleOp::create(
FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));

View File

@ -302,7 +302,8 @@ ModuleTranslation::ModuleTranslation(Operation *module,
: mlirModule(module), llvmModule(std::move(llvmModule)),
debugTranslation(
std::make_unique<DebugTranslation>(module, *this->llvmModule)),
ompDialect(module->getContext()->getOrLoadDialect<omp::OpenMPDialect>()),
ompDialect(
module->getContext()->getRegisteredDialect<omp::OpenMPDialect>()),
typeTranslator(this->llvmModule->getContext()) {
assert(satisfiesLLVMModule(mlirModule) &&
"mlirModule should honor LLVM's module semantics.");
@ -943,8 +944,8 @@ ModuleTranslation::lookupValues(ValueRange values) {
std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(
Operation *m, llvm::LLVMContext &llvmContext, StringRef name) {
m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
if (auto dataLayoutAttr =
m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()))
llvmModule->setDataLayout(dataLayoutAttr.cast<StringAttr>().getValue());

View File

@ -12,13 +12,6 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
// Forward declaration from Dialect.h
template <typename ConcreteDialect>
void registerDialect(DialectRegistry &registry);
namespace linalg {
class LinalgDialect;
} // end namespace linalg
#define GEN_PASS_CLASSES
#include "mlir/Transforms/Passes.h.inc"

View File

@ -383,7 +383,6 @@ static int printStandardTypes(MlirContext ctx) {
int main() {
mlirRegisterAllDialects();
MlirContext ctx = mlirContextCreate();
mlirContextLoadAllDialects(ctx);
MlirLocation location = mlirLocationUnknownGet(ctx);
MlirModule moduleOp = makeAdd(ctx, location);

View File

@ -36,18 +36,16 @@ using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
static MLIRContext &globalContext() {
static thread_local MLIRContext context(/*loadAllDialects=*/false);
static thread_local bool initOnce = [&]() {
// clang-format off
context.loadDialect<AffineDialect,
scf::SCFDialect,
linalg::LinalgDialect,
StandardOpsDialect,
vector::VectorDialect>();
// clang-format on
static bool init_once = []() {
registerDialect<AffineDialect>();
registerDialect<linalg::LinalgDialect>();
registerDialect<scf::SCFDialect>();
registerDialect<StandardOpsDialect>();
registerDialect<vector::VectorDialect>();
return true;
}();
(void)initOnce;
(void)init_once;
static thread_local MLIRContext context;
context.allowUnregisteredDialects();
return context;
}

View File

@ -19,19 +19,18 @@
using namespace mlir;
// Load the SDBM dialect
static DialectRegistration<SDBMDialect> SDBMRegistration;
static MLIRContext *ctx() {
static thread_local MLIRContext context(/*loadAllDialects=*/false);
static thread_local bool once =
(context.getOrLoadDialect<SDBMDialect>(), true);
(void)once;
static thread_local MLIRContext context;
return &context;
}
static SDBMDialect *dialect() {
static thread_local SDBMDialect *d = nullptr;
if (!d) {
d = ctx()->getOrLoadDialect<SDBMDialect>();
d = ctx()->getRegisteredDialect<SDBMDialect>();
}
return d;
}

View File

@ -14,7 +14,6 @@
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
@ -73,9 +72,6 @@ struct VectorizerTestPass
: public PassWrapper<VectorizerTestPass, FunctionPass> {
static constexpr auto kTestAffineMapOpName = "test_affine_map";
static constexpr auto kTestAffineMapAttrName = "affine_map";
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect>();
}
void runOnFunction() override;
void testVectorShapeRatio(llvm::raw_ostream &outs);

View File

@ -30,7 +30,7 @@ void PrintOpAvailability::runOnFunction() {
auto f = getFunction();
llvm::outs() << f.getName() << "\n";
Dialect *spvDialect = getContext().getLoadedDialect("spv");
Dialect *spvDialect = getContext().getRegisteredDialect("spv");
f.getOperation()->walk([&](Operation *op) {
if (op->getDialect() != spvDialect)

View File

@ -21,10 +21,6 @@
using namespace mlir;
void mlir::registerTestDialect(DialectRegistry &registry) {
registry.insert<TestDialect>();
}
//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//

View File

@ -37,8 +37,6 @@ namespace mlir {
#define GET_OP_CLASSES
#include "TestOps.h.inc"
void registerTestDialect(DialectRegistry &registry);
} // end namespace mlir
#endif // MLIR_TESTDIALECT_H

View File

@ -768,10 +768,6 @@ struct TestTypeConversionProducer
struct TestTypeConversionDriver
: public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TestDialect>();
}
void runOnOperation() override {
// Initialize the type converter.
TypeConverter converter;

View File

@ -11,7 +11,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@ -20,9 +19,6 @@ using namespace mlir;
namespace {
struct TestAllReduceLoweringPass
: public PassWrapper<TestAllReduceLoweringPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect>();
}
void runOnOperation() override {
OwningRewritePatternList patterns;
populateGpuRewritePatterns(&getContext(), patterns);

View File

@ -116,10 +116,6 @@ struct TestBufferPlacementPreparationPass
patterns->insert<GenericOpConverter>(context, placer, converter);
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
}
void runOnOperation() override {
MLIRContext &context = this->getContext();
ConversionTarget target(context);

View File

@ -13,9 +13,6 @@
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/MemoryPromotion.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
@ -29,10 +26,6 @@ namespace {
class TestGpuMemoryPromotionPass
: public PassWrapper<TestGpuMemoryPromotionPass,
OperationPass<gpu::GPUFuncOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect, scf::SCFDialect>();
}
void runOnOperation() override {
gpu::GPUFuncOp op = getOperation();
for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {

View File

@ -10,7 +10,6 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Pass/Pass.h"
@ -23,9 +22,6 @@ struct TestLinalgHoisting
: public PassWrapper<TestLinalgHoisting, FunctionPass> {
TestLinalgHoisting() = default;
TestLinalgHoisting(const TestLinalgHoisting &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect>();
}
void runOnFunction() override;

View File

@ -15,7 +15,6 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@ -31,16 +30,6 @@ struct TestLinalgTransforms
TestLinalgTransforms() = default;
TestLinalgTransforms(const TestLinalgTransforms &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
// clang-format off
registry.insert<AffineDialect,
scf::SCFDialect,
StandardOpsDialect,
vector::VectorDialect,
gpu::GPUDialect>();
// clang-format on
}
void runOnFunction() override;
Option<bool> testPatterns{*this, "test-patterns",

View File

@ -8,9 +8,6 @@
#include <type_traits>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
@ -131,11 +128,6 @@ struct TestVectorTransferFullPartialSplitPatterns
TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns(
const TestVectorTransferFullPartialSplitPatterns &pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
}
Option<bool> useLinalgOps{
*this, "use-linalg-copy",
llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "

View File

@ -1,5 +1,5 @@
// RUN: mlir-opt --show-dialects | FileCheck %s
// CHECK: Available Dialects:
// CHECK: Registered Dialects:
// CHECK: affine
// CHECK: gpu
// CHECK: linalg

View File

@ -1703,7 +1703,7 @@ int main(int argc, char **argv) {
if (testEmitIncludeTdHeader)
output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
MLIRContext context(/*loadAllDialects=*/false);
MLIRContext context;
llvm::SourceMgr mgr;
mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
Parser parser(mgr, &context);

View File

@ -48,7 +48,6 @@ void registerTestConstantFold();
void registerTestConvertGPUKernelToCubinPass();
void registerTestConvertGPUKernelToHsacoPass();
void registerTestDominancePass();
void registerTestDialect(DialectRegistry &);
void registerTestExpandTanhPass();
void registerTestFunc();
void registerTestGpuMemoryPromotionPass();
@ -131,10 +130,5 @@ int main(int argc, char **argv) {
#ifdef MLIR_INCLUDE_TESTS
registerTestPasses();
#endif
DialectRegistry registry;
registerAllDialects(registry);
registerTestDialect(registry);
return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver\n",
registry,
/*preloadDialectsInContext=*/false));
return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver"));
}

View File

@ -61,14 +61,11 @@ filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
/// {2}: initialization code that is emitted in the ctor body before calling
/// initialize()
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::Dialect {
explicit {0}(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), context,
::mlir::TypeID::get<{0}>()) {{
{2}
initialize();
}
void initialize();
@ -77,12 +74,6 @@ public:
static ::llvm::StringRef getDialectNamespace() { return "{1}"; }
)";
/// Registration for a single dependent dialect: to be inserted in the ctor
/// above for each dependent dialect.
const char *const dialectRegistrationTemplate = R"(
getContext()->getOrLoadDialect<{0}>();
)";
/// The code block for the attribute parser/printer hooks.
static const char *const attrParserDecl = R"(
/// Parse an attribute registered to this dialect.
@ -145,18 +136,9 @@ static void emitDialectDecl(Dialect &dialect,
iterator_range<DialectFilterIterator> dialectAttrs,
iterator_range<DialectFilterIterator> dialectTypes,
raw_ostream &os) {
/// Build the list of dependent dialects
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
for (StringRef dependentDialect : dialect.getDependentDialects())
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
}
// Emit the start of the decl.
std::string cppName = dialect.getCppClassName();
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
dependentDialectRegistrations);
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName());
// Check for any attributes/types registered to this dialect. If there are,
// add the hooks for parsing/printing.

View File

@ -36,7 +36,6 @@ static llvm::cl::opt<std::string>
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2): The command line argument for the pass.
/// {3}: The dependent dialects registration.
const char *const passDeclBegin = R"(
//===----------------------------------------------------------------------===//
// {0}
@ -64,20 +63,9 @@ public:
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
/// Return the dialect that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
{3}
}
protected:
)";
/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the `getDependentDialects` above.
const char *const dialectRegistrationTemplate = R"(
registry.insert<{0}>();
)";
/// Emit the declarations for each of the pass options.
static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
for (const PassOption &opt : pass.getOptions()) {
@ -106,15 +94,8 @@ static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
static void emitPassDecl(const Pass &pass, raw_ostream &os) {
StringRef defName = pass.getDef()->getName();
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
for (StringRef dependentDialect : pass.getDependentDialects())
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
}
os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
pass.getArgument(), dependentDialectRegistrations);
pass.getArgument());
emitPassOptionDecls(pass, os);
emitPassStatisticDecls(pass, os);
os << "};\n";

View File

@ -88,8 +88,7 @@ int main(int argc, char **argv) {
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) {
MLIRContext context(false);
registerAllDialects(&context);
MLIRContext context;
context.allowUnregisteredDialects();
context.printOpOnDiagnostic(!verifyDiagnostics);
llvm::SourceMgr sourceMgr;

View File

@ -17,6 +17,9 @@
using namespace mlir;
using namespace mlir::quant;
// Load the quant dialect
static DialectRegistration<QuantizationDialect> QuantOpsRegistration;
namespace {
// Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
@ -75,8 +78,7 @@ UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
}
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
MLIRContext ctx(/*loadAllDialects=*/false);
ctx.getOrLoadDialect<QuantizationDialect>();
MLIRContext ctx;
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
@ -93,8 +95,7 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
MLIRContext ctx(/*loadAllDialects=*/false);
ctx.getOrLoadDialect<QuantizationDialect>();
MLIRContext ctx;
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
@ -118,8 +119,7 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
MLIRContext ctx(/*loadAllDialects=*/false);
ctx.getOrLoadDialect<QuantizationDialect>();
MLIRContext ctx;
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
@ -143,8 +143,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
MLIRContext ctx(/*loadAllDialects=*/false);
ctx.getOrLoadDialect<QuantizationDialect>();
MLIRContext ctx;
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);

View File

@ -38,8 +38,7 @@ using ::testing::StrEq;
/// diagnostic checking utilities.
class DeserializationTest : public ::testing::Test {
protected:
DeserializationTest() : context(/*loadAllDialects=*/false) {
context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
DeserializationTest() {
// Register a diagnostic handler to capture the diagnostic so that we can
// check it later.
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {

View File

@ -36,10 +36,7 @@ using namespace mlir;
class SerializationTest : public ::testing::Test {
protected:
SerializationTest() : context(/*loadAllDialects=*/false) {
context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
createModuleOp();
}
SerializationTest() { createModuleOp(); }
void createModuleOp() {
OpBuilder builder(&context);

View File

@ -32,7 +32,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
namespace {
TEST(DenseSplatTest, BoolSplat) {
MLIRContext context(false);
MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
@ -57,7 +57,7 @@ TEST(DenseSplatTest, BoolSplat) {
TEST(DenseSplatTest, LargeBoolSplat) {
constexpr int64_t boolCount = 56;
MLIRContext context(false);
MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
@ -80,7 +80,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
}
TEST(DenseSplatTest, BoolNonSplat) {
MLIRContext context(false);
MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({6}, boolTy);
@ -92,7 +92,7 @@ TEST(DenseSplatTest, BoolNonSplat) {
TEST(DenseSplatTest, OddIntSplat) {
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
MLIRContext context(false);
MLIRContext context;
constexpr size_t intWidth = 19;
IntegerType intTy = IntegerType::get(intWidth, &context);
APInt value(intWidth, 10);
@ -101,7 +101,7 @@ TEST(DenseSplatTest, OddIntSplat) {
}
TEST(DenseSplatTest, Int32Splat) {
MLIRContext context(false);
MLIRContext context;
IntegerType intTy = IntegerType::get(32, &context);
int value = 64;
@ -109,7 +109,7 @@ TEST(DenseSplatTest, Int32Splat) {
}
TEST(DenseSplatTest, IntAttrSplat) {
MLIRContext context(false);
MLIRContext context;
IntegerType intTy = IntegerType::get(85, &context);
Attribute value = IntegerAttr::get(intTy, 109);
@ -117,7 +117,7 @@ TEST(DenseSplatTest, IntAttrSplat) {
}
TEST(DenseSplatTest, F32Splat) {
MLIRContext context(false);
MLIRContext context;
FloatType floatTy = FloatType::getF32(&context);
float value = 10.0;
@ -125,7 +125,7 @@ TEST(DenseSplatTest, F32Splat) {
}
TEST(DenseSplatTest, F64Splat) {
MLIRContext context(false);
MLIRContext context;
FloatType floatTy = FloatType::getF64(&context);
double value = 10.0;
@ -133,7 +133,7 @@ TEST(DenseSplatTest, F64Splat) {
}
TEST(DenseSplatTest, FloatAttrSplat) {
MLIRContext context(false);
MLIRContext context;
FloatType floatTy = FloatType::getF32(&context);
Attribute value = FloatAttr::get(floatTy, 10.0);
@ -141,7 +141,7 @@ TEST(DenseSplatTest, FloatAttrSplat) {
}
TEST(DenseSplatTest, BF16Splat) {
MLIRContext context(false);
MLIRContext context;
FloatType floatTy = FloatType::getBF16(&context);
Attribute value = FloatAttr::get(floatTy, 10.0);
@ -149,7 +149,7 @@ TEST(DenseSplatTest, BF16Splat) {
}
TEST(DenseSplatTest, StringSplat) {
MLIRContext context(false);
MLIRContext context;
Type stringType =
OpaqueType::get(Identifier::get("test", &context), "string", &context);
StringRef value = "test-string";
@ -157,7 +157,7 @@ TEST(DenseSplatTest, StringSplat) {
}
TEST(DenseSplatTest, StringAttrSplat) {
MLIRContext context(false);
MLIRContext context;
Type stringType =
OpaqueType::get(Identifier::get("test", &context), "string", &context);
Attribute stringAttr = StringAttr::get("test-string", stringType);
@ -165,28 +165,28 @@ TEST(DenseSplatTest, StringAttrSplat) {
}
TEST(DenseComplexTest, ComplexFloatSplat) {
MLIRContext context(false);
MLIRContext context;
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
std::complex<float> value(10.0, 15.0);
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexIntSplat) {
MLIRContext context(false);
MLIRContext context;
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
std::complex<int64_t> value(10, 15);
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexAPFloatSplat) {
MLIRContext context(false);
MLIRContext context;
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexAPIntSplat) {
MLIRContext context(false);
MLIRContext context;
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
testSplat(complexType, value);

View File

@ -26,12 +26,12 @@ struct AnotherTestDialect : public Dialect {
};
TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
MLIRContext context(false);
MLIRContext context;
// Registering a dialect with the same namespace twice should result in a
// failure.
context.loadDialect<TestDialect>();
ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
context.getOrCreateDialect<TestDialect>();
ASSERT_DEATH(context.getOrCreateDialect<AnotherTestDialect>(), "");
}
} // end namespace

View File

@ -25,7 +25,7 @@ static Operation *createOp(MLIRContext *context,
namespace {
TEST(OperandStorageTest, NonResizable) {
MLIRContext context(false);
MLIRContext context;
Builder builder(&context);
Operation *useOp =
@ -49,7 +49,7 @@ TEST(OperandStorageTest, NonResizable) {
}
TEST(OperandStorageTest, Resizable) {
MLIRContext context(false);
MLIRContext context;
Builder builder(&context);
Operation *useOp =
@ -77,7 +77,7 @@ TEST(OperandStorageTest, Resizable) {
}
TEST(OperandStorageTest, RangeReplace) {
MLIRContext context(false);
MLIRContext context;
Builder builder(&context);
Operation *useOp =
@ -113,7 +113,7 @@ TEST(OperandStorageTest, RangeReplace) {
}
TEST(OperandStorageTest, MutableRange) {
MLIRContext context(false);
MLIRContext context;
Builder builder(&context);
Operation *useOp =

View File

@ -29,7 +29,7 @@ struct OpSpecificAnalysis {
};
TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
MLIRContext context(false);
MLIRContext context;
// Test fine grain invalidation of the module analysis manager.
OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
@ -50,7 +50,7 @@ TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
}
TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
MLIRContext context(false);
MLIRContext context;
Builder builder(&context);
// Create a function and a module.
@ -79,7 +79,7 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
}
TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
MLIRContext context(false);
MLIRContext context;
Builder builder(&context);
// Create a function and a module.
@ -122,7 +122,7 @@ struct CustomInvalidatingAnalysis {
};
TEST(AnalysisManagerTest, CustomInvalidation) {
MLIRContext context(false);
MLIRContext context;
Builder builder(&context);
// Create a function and a module.

View File

@ -17,17 +17,18 @@
using namespace mlir;
/// Load the SDBM dialect.
static DialectRegistration<SDBMDialect> SDBMRegistration;
static MLIRContext *ctx() {
static thread_local MLIRContext context(false);
context.getOrLoadDialect<SDBMDialect>();
static thread_local MLIRContext context;
return &context;
}
static SDBMDialect *dialect() {
static thread_local SDBMDialect *d = nullptr;
if (!d) {
d = ctx()->getOrLoadDialect<SDBMDialect>();
d = ctx()->getRegisteredDialect<SDBMDialect>();
}
return d;
}

View File

@ -25,16 +25,11 @@ namespace mlir {
// Test Fixture
//===----------------------------------------------------------------------===//
static MLIRContext &getContext() {
static MLIRContext ctx(false);
ctx.getOrLoadDialect<TestDialect>();
return ctx;
}
/// Test fixture for providing basic utilities for testing.
class OpBuildGenTest : public ::testing::Test {
protected:
OpBuildGenTest()
: ctx(getContext()), builder(&ctx), loc(builder.getUnknownLoc()),
: ctx{}, builder(&ctx), loc(builder.getUnknownLoc()),
i32Ty(builder.getI32Type()), f32Ty(builder.getF32Type()),
cstI32(builder.create<TableGenConstant>(loc, i32Ty)),
cstF32(builder.create<TableGenConstant>(loc, f32Ty)),
@ -91,7 +86,7 @@ protected:
}
protected:
MLIRContext &ctx;
MLIRContext ctx;
OpBuilder builder;
Location loc;
Type i32Ty;

View File

@ -42,7 +42,7 @@ static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
/// Validates that test::TestStruct::classof correctly identifies a valid
/// test::TestStruct.
TEST(StructsGenTest, ClassofTrue) {
mlir::MLIRContext context(false);
mlir::MLIRContext context;
auto structAttr = getTestStruct(&context);
ASSERT_TRUE(test::TestStruct::classof(structAttr));
}