forked from OSchip/llvm-project
Revert "Separate the Registration from Loading dialects in the Context"
This reverts commit e1de2b7550
.
Broke a build bot.
This commit is contained in:
parent
4cbceb74bb
commit
d84fe55e0d
|
@ -24,16 +24,9 @@
|
|||
int main(int argc, char **argv) {
|
||||
mlir::registerAllDialects();
|
||||
mlir::registerAllPasses();
|
||||
|
||||
mlir::registerDialect<mlir::standalone::StandaloneDialect>();
|
||||
// TODO: Register standalone passes here.
|
||||
|
||||
mlir::DialectRegistry registry;
|
||||
mlir::registerDialect<mlir::standalone::StandaloneDialect>();
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
// Add the following to include *all* MLIR Core dialects, or selectively
|
||||
// include what you need like above. You only need to register dialects that
|
||||
// will be *parsed* by the tool, not the one generated
|
||||
// registerAllDialects(registry);
|
||||
|
||||
return failed(
|
||||
mlir::MlirOptMain(argc, argv, "Standalone optimizer driver\n", registry));
|
||||
return failed(mlir::MlirOptMain(argc, argv, "Standalone optimizer driver\n"));
|
||||
}
|
||||
|
|
|
@ -68,9 +68,10 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
|||
}
|
||||
|
||||
int dumpMLIR() {
|
||||
mlir::MLIRContext context(/*loadAllDialects=*/false);
|
||||
// Load our Dialect in this MLIR Context.
|
||||
context.getOrLoadDialect<mlir::toy::ToyDialect>();
|
||||
// Register our Dialect with MLIR.
|
||||
mlir::registerDialect<mlir::toy::ToyDialect>();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
|
||||
// Handle '.toy' input to the compiler.
|
||||
if (inputType != InputType::MLIR &&
|
||||
|
|
|
@ -102,10 +102,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
|
|||
}
|
||||
|
||||
int dumpMLIR() {
|
||||
mlir::MLIRContext context(/*loadAllDialects=*/false);
|
||||
// Load our Dialect in this MLIR Context.
|
||||
context.getOrLoadDialect<mlir::toy::ToyDialect>();
|
||||
// Register our Dialect with MLIR.
|
||||
mlir::registerDialect<mlir::toy::ToyDialect>();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
llvm::SourceMgr sourceMgr;
|
||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
|
||||
|
|
|
@ -103,10 +103,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
|
|||
}
|
||||
|
||||
int dumpMLIR() {
|
||||
mlir::MLIRContext context(/*loadAllDialects=*/false);
|
||||
// Load our Dialect in this MLIR Context.
|
||||
context.getOrLoadDialect<mlir::toy::ToyDialect>();
|
||||
// Register our Dialect with MLIR.
|
||||
mlir::registerDialect<mlir::toy::ToyDialect>();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
llvm::SourceMgr sourceMgr;
|
||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
|
||||
|
|
|
@ -256,9 +256,6 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
namespace {
|
||||
struct ToyToAffineLoweringPass
|
||||
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect, StandardOpsDialect>();
|
||||
}
|
||||
void runOnFunction() final;
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
|
|
@ -106,10 +106,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
|
|||
}
|
||||
|
||||
int dumpMLIR() {
|
||||
mlir::MLIRContext context(/*loadAllDialects=*/false);
|
||||
// Load our Dialect in this MLIR Context.
|
||||
context.getOrLoadDialect<mlir::toy::ToyDialect>();
|
||||
// Register our Dialect with MLIR.
|
||||
mlir::registerDialect<mlir::toy::ToyDialect>();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
llvm::SourceMgr sourceMgr;
|
||||
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
|
||||
|
|
|
@ -255,9 +255,6 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
namespace {
|
||||
struct ToyToAffineLoweringPass
|
||||
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect, StandardOpsDialect>();
|
||||
}
|
||||
void runOnFunction() final;
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
|
|
@ -159,9 +159,6 @@ private:
|
|||
namespace {
|
||||
struct ToyToLLVMLoweringPass
|
||||
: public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<LLVM::LLVMDialect, scf::SCFDialect>();
|
||||
}
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
|
|
@ -255,10 +255,10 @@ int main(int argc, char **argv) {
|
|||
|
||||
// If we aren't dumping the AST, then we are compiling with/to MLIR.
|
||||
|
||||
mlir::MLIRContext context(/*loadAllDialects=*/false);
|
||||
// Load our Dialect in this MLIR Context.
|
||||
context.getOrLoadDialect<mlir::toy::ToyDialect>();
|
||||
// Register our Dialect with MLIR.
|
||||
mlir::registerDialect<mlir::toy::ToyDialect>();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
if (int error = loadAndProcessMLIR(context, module))
|
||||
return error;
|
||||
|
|
|
@ -256,9 +256,6 @@ struct TransposeOpLowering : public ConversionPattern {
|
|||
namespace {
|
||||
struct ToyToAffineLoweringPass
|
||||
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect, StandardOpsDialect>();
|
||||
}
|
||||
void runOnFunction() final;
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
|
|
@ -159,9 +159,6 @@ private:
|
|||
namespace {
|
||||
struct ToyToLLVMLoweringPass
|
||||
: public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<LLVM::LLVMDialect, scf::SCFDialect>();
|
||||
}
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
|
|
@ -256,10 +256,10 @@ int main(int argc, char **argv) {
|
|||
|
||||
// If we aren't dumping the AST, then we are compiling with/to MLIR.
|
||||
|
||||
mlir::MLIRContext context(/*loadAllDialects=*/false);
|
||||
// Load our Dialect in this MLIR Context.
|
||||
context.getOrLoadDialect<mlir::toy::ToyDialect>();
|
||||
// Register our Dialect with MLIR.
|
||||
mlir::registerDialect<mlir::toy::ToyDialect>();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
if (int error = loadAndProcessMLIR(context, module))
|
||||
return error;
|
||||
|
|
|
@ -88,12 +88,6 @@ MlirContext mlirContextCreate();
|
|||
/** Takes an MLIR context owned by the caller and destroys it. */
|
||||
void mlirContextDestroy(MlirContext context);
|
||||
|
||||
/** Load all the globally registered dialects in the provided context.
|
||||
* TODO: remove the concept of globally registered dialect by exposing the
|
||||
* DialectRegistry.
|
||||
*/
|
||||
void mlirContextLoadAllDialects(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Location API. */
|
||||
/*============================================================================*/
|
||||
|
|
|
@ -66,11 +66,6 @@ def ConvertAffineToStandard : Pass<"lower-affine"> {
|
|||
`affine.apply`.
|
||||
}];
|
||||
let constructor = "mlir::createLowerAffinePass()";
|
||||
let dependentDialects = [
|
||||
"scf::SCFDialect",
|
||||
"StandardOpsDialect",
|
||||
"vector::VectorDialect"
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -81,7 +76,6 @@ def ConvertAVX512ToLLVM : Pass<"convert-avx512-to-llvm", "ModuleOp"> {
|
|||
let summary = "Convert the operations from the avx512 dialect into the LLVM "
|
||||
"dialect";
|
||||
let constructor = "mlir::createConvertAVX512ToLLVMPass()";
|
||||
let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -104,7 +98,6 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
|
|||
def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
|
||||
let summary = "Generate NVVM operations for gpu operations";
|
||||
let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()";
|
||||
let dependentDialects = ["NVVM::NVVMDialect"];
|
||||
let options = [
|
||||
Option<"indexBitwidth", "index-bitwidth", "unsigned",
|
||||
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
|
||||
|
@ -119,7 +112,6 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
|
|||
def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
|
||||
let summary = "Generate ROCDL operations for gpu operations";
|
||||
let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
|
||||
let dependentDialects = ["ROCDL::ROCDLDialect"];
|
||||
let options = [
|
||||
Option<"indexBitwidth", "index-bitwidth", "unsigned",
|
||||
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
|
||||
|
@ -134,7 +126,6 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
|
|||
def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
|
||||
let summary = "Convert GPU dialect to SPIR-V dialect";
|
||||
let constructor = "mlir::createConvertGPUToSPIRVPass()";
|
||||
let dependentDialects = ["spirv::SPIRVDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -145,7 +136,6 @@ def ConvertGpuLaunchFuncToVulkanLaunchFunc
|
|||
: Pass<"convert-gpu-launch-to-vulkan-launch", "ModuleOp"> {
|
||||
let summary = "Convert gpu.launch_func to vulkanLaunch external call";
|
||||
let constructor = "mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass()";
|
||||
let dependentDialects = ["spirv::SPIRVDialect"];
|
||||
}
|
||||
|
||||
def ConvertVulkanLaunchFuncToVulkanCalls
|
||||
|
@ -153,7 +143,6 @@ def ConvertVulkanLaunchFuncToVulkanCalls
|
|||
let summary = "Convert vulkanLaunch external call to Vulkan runtime external "
|
||||
"calls";
|
||||
let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()";
|
||||
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -164,7 +153,6 @@ def ConvertLinalgToLLVM : Pass<"convert-linalg-to-llvm", "ModuleOp"> {
|
|||
let summary = "Convert the operations from the linalg dialect into the LLVM "
|
||||
"dialect";
|
||||
let constructor = "mlir::createConvertLinalgToLLVMPass()";
|
||||
let dependentDialects = ["scf::SCFDialect", "LLVM::LLVMDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -175,7 +163,6 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
|
|||
let summary = "Convert the operations from the linalg dialect into the "
|
||||
"Standard dialect";
|
||||
let constructor = "mlir::createConvertLinalgToStandardPass()";
|
||||
let dependentDialects = ["StandardOpsDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -185,7 +172,6 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
|
|||
def ConvertLinalgToSPIRV : Pass<"convert-linalg-to-spirv", "ModuleOp"> {
|
||||
let summary = "Convert Linalg ops to SPIR-V ops";
|
||||
let constructor = "mlir::createLinalgToSPIRVPass()";
|
||||
let dependentDialects = ["spirv::SPIRVDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -196,7 +182,6 @@ def SCFToStandard : Pass<"convert-scf-to-std"> {
|
|||
let summary = "Convert SCF dialect to Standard dialect, replacing structured"
|
||||
" control flow with a CFG";
|
||||
let constructor = "mlir::createLowerToCFGPass()";
|
||||
let dependentDialects = ["StandardOpsDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -206,7 +191,6 @@ def SCFToStandard : Pass<"convert-scf-to-std"> {
|
|||
def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> {
|
||||
let summary = "Convert top-level AffineFor Ops to GPU kernels";
|
||||
let constructor = "mlir::createAffineForToGPUPass()";
|
||||
let dependentDialects = ["gpu::GPUDialect"];
|
||||
let options = [
|
||||
Option<"numBlockDims", "gpu-block-dims", "unsigned", /*default=*/"1u",
|
||||
"Number of GPU block dimensions for mapping">,
|
||||
|
@ -218,7 +202,6 @@ def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> {
|
|||
def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
|
||||
let summary = "Convert mapped scf.parallel ops to gpu launch operations";
|
||||
let constructor = "mlir::createParallelLoopToGpuPass()";
|
||||
let dependentDialects = ["AffineDialect", "gpu::GPUDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -229,7 +212,6 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
|
|||
let summary = "Convert operations from the shape dialect into the standard "
|
||||
"dialect";
|
||||
let constructor = "mlir::createConvertShapeToStandardPass()";
|
||||
let dependentDialects = ["StandardOpsDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -239,7 +221,6 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
|
|||
def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
|
||||
let summary = "Convert operations from the shape dialect to the SCF dialect";
|
||||
let constructor = "mlir::createConvertShapeToSCFPass()";
|
||||
let dependentDialects = ["scf::SCFDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -249,7 +230,6 @@ def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
|
|||
def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> {
|
||||
let summary = "Convert SPIR-V dialect to LLVM dialect";
|
||||
let constructor = "mlir::createConvertSPIRVToLLVMPass()";
|
||||
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -284,7 +264,6 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
|
|||
LLVM IR types.
|
||||
}];
|
||||
let constructor = "mlir::createLowerToLLVMPass()";
|
||||
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||
let options = [
|
||||
Option<"useAlignedAlloc", "use-aligned-alloc", "bool", /*default=*/"false",
|
||||
"Use aligned_alloc in place of malloc for heap allocations">,
|
||||
|
@ -312,13 +291,11 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
|
|||
def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> {
|
||||
let summary = "Legalize standard ops for SPIR-V lowering";
|
||||
let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()";
|
||||
let dependentDialects = ["spirv::SPIRVDialect"];
|
||||
}
|
||||
|
||||
def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> {
|
||||
let summary = "Convert Standard Ops to SPIR-V dialect";
|
||||
let constructor = "mlir::createConvertStandardToSPIRVPass()";
|
||||
let dependentDialects = ["spirv::SPIRVDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -329,7 +306,6 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
|
|||
let summary = "Lower the operations from the vector dialect into the SCF "
|
||||
"dialect";
|
||||
let constructor = "mlir::createConvertVectorToSCFPass()";
|
||||
let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
|
||||
let options = [
|
||||
Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false",
|
||||
"Perform full unrolling when converting vector transfers to SCF">,
|
||||
|
@ -344,7 +320,6 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
|
|||
let summary = "Lower the operations from the vector dialect into the LLVM "
|
||||
"dialect";
|
||||
let constructor = "mlir::createConvertVectorToLLVMPass()";
|
||||
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||
let options = [
|
||||
Option<"reassociateFPReductions", "reassociate-fp-reductions",
|
||||
"bool", /*default=*/"false",
|
||||
|
@ -360,7 +335,6 @@ def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> {
|
|||
let summary = "Lower the operations from the vector dialect into the ROCDL "
|
||||
"dialect";
|
||||
let constructor = "mlir::createConvertVectorToROCDLPass()";
|
||||
let dependentDialects = ["ROCDL::ROCDLDialect"];
|
||||
}
|
||||
|
||||
#endif // MLIR_CONVERSION_PASSES
|
||||
|
|
|
@ -94,7 +94,6 @@ def AffineLoopUnrollAndJam : FunctionPass<"affine-loop-unroll-jam"> {
|
|||
def AffineVectorize : FunctionPass<"affine-super-vectorize"> {
|
||||
let summary = "Vectorize to a target independent n-D vector abstraction";
|
||||
let constructor = "mlir::createSuperVectorizePass()";
|
||||
let dependentDialects = ["vector::VectorDialect"];
|
||||
let options = [
|
||||
ListOption<"vectorSizes", "virtual-vector-size", "int64_t",
|
||||
"Specify an n-D virtual vector size for vectorization",
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
#define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
|
|
@ -19,11 +19,6 @@ include "mlir/IR/OpBase.td"
|
|||
def LLVM_Dialect : Dialect {
|
||||
let name = "llvm";
|
||||
let cppNamespace = "LLVM";
|
||||
|
||||
/// FIXME: at the moment this is a dependency of the translation to LLVM IR,
|
||||
/// not really one of this dialect per-se.
|
||||
let dependentDialects = ["omp::OpenMPDialect"];
|
||||
|
||||
let hasRegionArgAttrVerify = 1;
|
||||
let hasOperationAttrVerify = 1;
|
||||
let extraClassDeclaration = [{
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
#ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
|
||||
#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
|
|
@ -23,7 +23,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
|||
def NVVM_Dialect : Dialect {
|
||||
let name = "nvvm";
|
||||
let cppNamespace = "NVVM";
|
||||
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#ifndef MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_
|
||||
#define MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
|
|
@ -23,7 +23,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
|||
def ROCDL_Dialect : Dialect {
|
||||
let name = "rocdl";
|
||||
let cppNamespace = "ROCDL";
|
||||
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -30,20 +30,17 @@ def LinalgFusion : FunctionPass<"linalg-fusion"> {
|
|||
def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
|
||||
let summary = "Fuse operations on RankedTensorType in linalg dialect";
|
||||
let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
|
||||
let dependentDialects = ["AffineDialect"];
|
||||
}
|
||||
|
||||
def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
|
||||
let summary = "Lower the operations from the linalg dialect into affine "
|
||||
"loops";
|
||||
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
|
||||
let dependentDialects = ["AffineDialect"];
|
||||
}
|
||||
|
||||
def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
|
||||
let summary = "Lower the operations from the linalg dialect into loops";
|
||||
let constructor = "mlir::createConvertLinalgToLoopsPass()";
|
||||
let dependentDialects = ["scf::SCFDialect", "AffineDialect"];
|
||||
}
|
||||
|
||||
def LinalgOnTensorsToBuffers : Pass<"convert-linalg-on-tensors-to-buffers", "ModuleOp"> {
|
||||
|
@ -57,7 +54,6 @@ def LinalgLowerToParallelLoops
|
|||
let summary = "Lower the operations from the linalg dialect into parallel "
|
||||
"loops";
|
||||
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
|
||||
let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
|
||||
}
|
||||
|
||||
def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
|
||||
|
@ -74,9 +70,6 @@ def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
|
|||
def LinalgTiling : FunctionPass<"linalg-tile"> {
|
||||
let summary = "Tile operations in the linalg dialect";
|
||||
let constructor = "mlir::createLinalgTilingPass()";
|
||||
let dependentDialects = [
|
||||
"AffineDialect", "scf::SCFDialect"
|
||||
];
|
||||
let options = [
|
||||
ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
|
||||
"Test generation of dynamic promoted buffers",
|
||||
|
@ -93,7 +86,6 @@ def LinalgTilingToParallelLoops
|
|||
"Test generation of dynamic promoted buffers",
|
||||
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
|
||||
];
|
||||
let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_PASSES
|
||||
|
|
|
@ -36,7 +36,6 @@ def SCFParallelLoopTiling : FunctionPass<"parallel-loop-tiling"> {
|
|||
"Factors to tile parallel loops by",
|
||||
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
|
||||
];
|
||||
let dependentDialects = ["AffineDialect"];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_SCF_PASSES
|
||||
|
|
|
@ -16,8 +16,6 @@
|
|||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/Support/TypeID.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
namespace mlir {
|
||||
class DialectAsmParser;
|
||||
class DialectAsmPrinter;
|
||||
|
@ -25,7 +23,7 @@ class DialectInterface;
|
|||
class OpBuilder;
|
||||
class Type;
|
||||
|
||||
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
|
||||
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
|
||||
|
||||
/// Dialects are groups of MLIR operations and behavior associated with the
|
||||
/// entire group. For example, hooks into other systems for constant folding,
|
||||
|
@ -214,87 +212,30 @@ private:
|
|||
/// A collection of registered dialect interfaces.
|
||||
DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
|
||||
|
||||
/// Registers a specific dialect creation function with the global registry.
|
||||
/// Used through the registerDialect template.
|
||||
/// Registrations are deduplicated by dialect TypeID and only the first
|
||||
/// registration will be used.
|
||||
static void
|
||||
registerDialectAllocator(TypeID typeID,
|
||||
const DialectAllocatorFunction &function);
|
||||
template <typename ConcreteDialect>
|
||||
friend void registerDialect();
|
||||
friend class MLIRContext;
|
||||
};
|
||||
|
||||
/// The DialectRegistry maps a dialect namespace to a constructor for the
|
||||
/// matching dialect.
|
||||
/// This allows for decoupling the list of dialects "available" from the
|
||||
/// dialects loaded in the Context. The parser in particular will lazily load
|
||||
/// dialects in in the Context as operations are encountered.
|
||||
class DialectRegistry {
|
||||
using MapTy =
|
||||
std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
|
||||
|
||||
public:
|
||||
template <typename ConcreteDialect>
|
||||
void insert() {
|
||||
insert(TypeID::get<ConcreteDialect>(),
|
||||
ConcreteDialect::getDialectNamespace(),
|
||||
static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
|
||||
// Just allocate the dialect, the context
|
||||
// takes ownership of it.
|
||||
return ctx->getOrLoadDialect<ConcreteDialect>();
|
||||
})));
|
||||
}
|
||||
|
||||
template <typename ConcreteDialect, typename OtherDialect,
|
||||
typename... MoreDialects>
|
||||
void insert() {
|
||||
insert<ConcreteDialect>();
|
||||
insert<OtherDialect, MoreDialects...>();
|
||||
}
|
||||
|
||||
/// Add a new dialect constructor to the registry.
|
||||
void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor);
|
||||
|
||||
/// Load a dialect for this namespace in the provided context.
|
||||
Dialect *loadByName(StringRef name, MLIRContext *context);
|
||||
|
||||
// Register all dialects available in the current registry with the registry
|
||||
// in the provided context.
|
||||
void appendTo(DialectRegistry &destination) {
|
||||
for (const auto &nameAndRegistrationIt : registry)
|
||||
destination.insert(nameAndRegistrationIt.second.first,
|
||||
nameAndRegistrationIt.first,
|
||||
nameAndRegistrationIt.second.second);
|
||||
}
|
||||
// Load all dialects available in the registry in the provided context.
|
||||
void loadAll(MLIRContext *context) {
|
||||
for (const auto &nameAndRegistrationIt : registry)
|
||||
nameAndRegistrationIt.second.second(context);
|
||||
}
|
||||
|
||||
MapTy::const_iterator begin() const { return registry.begin(); }
|
||||
MapTy::const_iterator end() const { return registry.end(); }
|
||||
|
||||
private:
|
||||
MapTy registry;
|
||||
};
|
||||
|
||||
/// Deprecated: this provides a global registry for convenience, while we're
|
||||
/// transitionning the registration mechanism to a stateless approach.
|
||||
DialectRegistry &getGlobalDialectRegistry();
|
||||
|
||||
/// Registers all dialects from the global registries with the
|
||||
/// specified MLIRContext. This won't load the dialects in the context,
|
||||
/// but only make them available for lazy loading by name.
|
||||
/// Registers all dialects and hooks from the global registries with the
|
||||
/// specified MLIRContext.
|
||||
/// Note: This method is not thread-safe.
|
||||
void registerAllDialects(MLIRContext *context);
|
||||
|
||||
/// Register and return the dialect with the given namespace in the provided
|
||||
/// context. Returns nullptr is there is no constructor registered for this
|
||||
/// dialect.
|
||||
inline Dialect *registerDialect(StringRef name, MLIRContext *context) {
|
||||
return getGlobalDialectRegistry().loadByName(name, context);
|
||||
}
|
||||
|
||||
/// Utility to register a dialect. Client can register their dialect with the
|
||||
/// global registry by calling registerDialect<MyDialect>();
|
||||
/// Note: This method is not thread-safe.
|
||||
template <typename ConcreteDialect> void registerDialect() {
|
||||
getGlobalDialectRegistry().insert<ConcreteDialect>();
|
||||
Dialect::registerDialectAllocator(
|
||||
TypeID::get<ConcreteDialect>(),
|
||||
[](MLIRContext *ctx) { ctx->getOrCreateDialect<ConcreteDialect>(); });
|
||||
}
|
||||
|
||||
/// DialectRegistration provides a global initializer that registers a Dialect
|
||||
|
|
|
@ -428,7 +428,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
|
|||
if (!attr.first.strref().contains('.'))
|
||||
return funcOp.emitOpError("arguments may only have dialect attributes");
|
||||
auto dialectNamePair = attr.first.strref().split('.');
|
||||
if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
|
||||
if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
|
||||
if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
|
||||
/*argIndex=*/i, attr)))
|
||||
return failure();
|
||||
|
@ -444,7 +444,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
|
|||
if (!attr.first.strref().contains('.'))
|
||||
return funcOp.emitOpError("results may only have dialect attributes");
|
||||
auto dialectNamePair = attr.first.strref().split('.');
|
||||
if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
|
||||
if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
|
||||
if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
|
||||
/*resultIndex=*/i,
|
||||
attr)))
|
||||
|
|
|
@ -19,12 +19,10 @@ namespace mlir {
|
|||
class AbstractOperation;
|
||||
class DiagnosticEngine;
|
||||
class Dialect;
|
||||
class DialectRegistry;
|
||||
class InFlightDiagnostic;
|
||||
class Location;
|
||||
class MLIRContextImpl;
|
||||
class StorageUniquer;
|
||||
DialectRegistry &getGlobalDialectRegistry();
|
||||
|
||||
/// MLIRContext is the top-level object for a collection of MLIR modules. It
|
||||
/// holds immortal uniqued objects like types, and the tables used to unique
|
||||
|
@ -36,69 +34,34 @@ DialectRegistry &getGlobalDialectRegistry();
|
|||
///
|
||||
class MLIRContext {
|
||||
public:
|
||||
/// Create a new Context.
|
||||
/// The loadAllDialects parameters allows to load all dialects from the global
|
||||
/// registry on Context construction. It is deprecated and will be removed
|
||||
/// soon.
|
||||
explicit MLIRContext(bool loadAllDialects = true);
|
||||
explicit MLIRContext();
|
||||
~MLIRContext();
|
||||
|
||||
/// Return information about all IR dialects loaded in the context.
|
||||
std::vector<Dialect *> getLoadedDialects();
|
||||
|
||||
/// Return the dialect registry associated with this context.
|
||||
DialectRegistry &getDialectRegistry();
|
||||
|
||||
/// Return information about all available dialects in the registry in this
|
||||
/// context.
|
||||
std::vector<StringRef> getAvailableDialects();
|
||||
/// Return information about all registered IR dialects.
|
||||
std::vector<Dialect *> getRegisteredDialects();
|
||||
|
||||
/// Get a registered IR dialect with the given namespace. If an exact match is
|
||||
/// not found, then return nullptr.
|
||||
Dialect *getLoadedDialect(StringRef name);
|
||||
Dialect *getRegisteredDialect(StringRef name);
|
||||
|
||||
/// Get a registered IR dialect for the given derived dialect type. The
|
||||
/// derived type must provide a static 'getDialectNamespace' method.
|
||||
template <typename T>
|
||||
T *getLoadedDialect() {
|
||||
return static_cast<T *>(getLoadedDialect(T::getDialectNamespace()));
|
||||
template <typename T> T *getRegisteredDialect() {
|
||||
return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace()));
|
||||
}
|
||||
|
||||
/// Get (or create) a dialect for the given derived dialect type. The derived
|
||||
/// type must provide a static 'getDialectNamespace' method.
|
||||
template <typename T>
|
||||
T *getOrLoadDialect() {
|
||||
return static_cast<T *>(
|
||||
getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() {
|
||||
T *getOrCreateDialect() {
|
||||
return static_cast<T *>(getOrCreateDialect(
|
||||
T::getDialectNamespace(), TypeID::get<T>(), [this]() {
|
||||
std::unique_ptr<T> dialect(new T(this));
|
||||
dialect->dialectID = TypeID::get<T>();
|
||||
return dialect;
|
||||
}));
|
||||
}
|
||||
|
||||
/// Load a dialect in the context.
|
||||
template <typename Dialect>
|
||||
void loadDialect() {
|
||||
getOrLoadDialect<Dialect>();
|
||||
}
|
||||
|
||||
/// Load a list dialects in the context.
|
||||
template <typename Dialect, typename OtherDialect, typename... MoreDialects>
|
||||
void loadDialect() {
|
||||
getOrLoadDialect<Dialect>();
|
||||
loadDialect<OtherDialect, MoreDialects...>();
|
||||
}
|
||||
|
||||
/// Deprecated: load all globally registered dialects into this context.
|
||||
/// This method will be removed soon, it can be used temporarily as we're
|
||||
/// phasing out the global registry.
|
||||
void loadAllGloballyRegisteredDialects();
|
||||
|
||||
/// Get (or create) a dialect for the given derived dialect name.
|
||||
/// The dialect will be loaded from the registry if no dialect is found.
|
||||
/// If no dialect is loaded for this name and none is available in the
|
||||
/// registry, returns nullptr.
|
||||
Dialect *getOrLoadDialect(StringRef name);
|
||||
|
||||
/// Return true if we allow to create operation for unregistered dialects.
|
||||
bool allowsUnregisteredDialects();
|
||||
|
||||
|
@ -160,12 +123,10 @@ private:
|
|||
const std::unique_ptr<MLIRContextImpl> impl;
|
||||
|
||||
/// Get a dialect for the provided namespace and TypeID: abort the program if
|
||||
/// a dialect exist for this namespace with different TypeID. If a dialect has
|
||||
/// not been loaded for this namespace/TypeID yet, use the provided ctor to
|
||||
/// create one on the fly and load it. Returns a pointer to the dialect owned
|
||||
/// by the context.
|
||||
Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
|
||||
function_ref<std::unique_ptr<Dialect>()> ctor);
|
||||
/// a dialect exist for this namespace with different TypeID. Returns a
|
||||
/// pointer to the dialect owned by the context.
|
||||
Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
|
||||
function_ref<std::unique_ptr<Dialect>()> ctor);
|
||||
|
||||
MLIRContext(const MLIRContext &) = delete;
|
||||
void operator=(const MLIRContext &) = delete;
|
||||
|
|
|
@ -244,11 +244,6 @@ class Dialect {
|
|||
// The description of the dialect.
|
||||
string description = ?;
|
||||
|
||||
// A list of dialects this dialect will load on construction as dependencies.
|
||||
// These are dialects that this dialect may involved in canonicalization
|
||||
// pattern or interfaces.
|
||||
list<string> dependentDialects = [];
|
||||
|
||||
// The C++ namespace that ops of this dialect should be placed into.
|
||||
//
|
||||
// By default, uses the name of the dialect as the only namespace. To avoid
|
||||
|
|
|
@ -35,35 +35,30 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
// Add all the MLIR dialects to the provided registry.
|
||||
inline void registerAllDialects(DialectRegistry ®istry) {
|
||||
// clang-format off
|
||||
registry.insert<acc::OpenACCDialect,
|
||||
AffineDialect,
|
||||
avx512::AVX512Dialect,
|
||||
gpu::GPUDialect,
|
||||
LLVM::LLVMAVX512Dialect,
|
||||
LLVM::LLVMDialect,
|
||||
linalg::LinalgDialect,
|
||||
scf::SCFDialect,
|
||||
omp::OpenMPDialect,
|
||||
quant::QuantizationDialect,
|
||||
spirv::SPIRVDialect,
|
||||
StandardOpsDialect,
|
||||
vector::VectorDialect,
|
||||
NVVM::NVVMDialect,
|
||||
ROCDL::ROCDLDialect,
|
||||
SDBMDialect,
|
||||
shape::ShapeDialect>();
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// This function should be called before creating any MLIRContext if one expect
|
||||
// all the possible dialects to be made available to the context automatically.
|
||||
inline void registerAllDialects() {
|
||||
static bool initOnce =
|
||||
([]() { registerAllDialects(getGlobalDialectRegistry()); }(), true);
|
||||
(void)initOnce;
|
||||
static bool init_once = []() {
|
||||
registerDialect<acc::OpenACCDialect>();
|
||||
registerDialect<AffineDialect>();
|
||||
registerDialect<avx512::AVX512Dialect>();
|
||||
registerDialect<gpu::GPUDialect>();
|
||||
registerDialect<LLVM::LLVMAVX512Dialect>();
|
||||
registerDialect<LLVM::LLVMDialect>();
|
||||
registerDialect<linalg::LinalgDialect>();
|
||||
registerDialect<scf::SCFDialect>();
|
||||
registerDialect<omp::OpenMPDialect>();
|
||||
registerDialect<quant::QuantizationDialect>();
|
||||
registerDialect<spirv::SPIRVDialect>();
|
||||
registerDialect<StandardOpsDialect>();
|
||||
registerDialect<vector::VectorDialect>();
|
||||
registerDialect<NVVM::NVVMDialect>();
|
||||
registerDialect<ROCDL::ROCDLDialect>();
|
||||
registerDialect<SDBMDialect>();
|
||||
registerDialect<shape::ShapeDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ void registerAVX512ToLLVMIRTranslation();
|
|||
// expects all the possible translations to be made available to the context
|
||||
// automatically.
|
||||
inline void registerAllTranslations() {
|
||||
static bool initOnce = []() {
|
||||
static bool init_once = []() {
|
||||
registerFromLLVMIRTranslation();
|
||||
registerFromSPIRVTranslation();
|
||||
registerToLLVMIRTranslation();
|
||||
|
@ -38,7 +38,7 @@ inline void registerAllTranslations() {
|
|||
registerAVX512ToLLVMIRTranslation();
|
||||
return true;
|
||||
}();
|
||||
(void)initOnce;
|
||||
(void)init_once;
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
#ifndef MLIR_PASS_PASS_H
|
||||
#define MLIR_PASS_PASS_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/Pass/AnalysisManager.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
|
@ -58,13 +57,6 @@ public:
|
|||
/// Returns the derived pass name.
|
||||
virtual StringRef getName() const = 0;
|
||||
|
||||
/// Register dependent dialects for the current pass.
|
||||
/// A pass is expected to register the dialects it will create entities for
|
||||
/// (Operations, Types, Attributes), other than dialect that exists in the
|
||||
/// input. For example, a pass that converts from Linalg to Affine would
|
||||
/// register the Affine dialect but does not need to register Linalg.
|
||||
virtual void getDependentDialects(DialectRegistry ®istry) const {}
|
||||
|
||||
/// Returns the command line argument used when registering this pass. Return
|
||||
/// an empty string if one does not exist.
|
||||
virtual StringRef getArgument() const {
|
||||
|
|
|
@ -78,9 +78,6 @@ class PassBase<string passArg, string base> {
|
|||
// A C++ constructor call to create an instance of this pass.
|
||||
code constructor = [{}];
|
||||
|
||||
// A list of dialects this pass may produce entities in.
|
||||
list<string> dependentDialects = [];
|
||||
|
||||
// A set of options provided by this pass.
|
||||
list<Option> options = [];
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
#ifndef MLIR_PASS_PASSMANAGER_H
|
||||
#define MLIR_PASS_PASSMANAGER_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
|
@ -59,14 +58,6 @@ public:
|
|||
pass_iterator end();
|
||||
iterator_range<pass_iterator> getPasses() { return {begin(), end()}; }
|
||||
|
||||
using const_pass_iterator = llvm::pointee_iterator<
|
||||
std::vector<std::unique_ptr<Pass>>::const_iterator>;
|
||||
const_pass_iterator begin() const;
|
||||
const_pass_iterator end() const;
|
||||
iterator_range<const_pass_iterator> getPasses() const {
|
||||
return {begin(), end()};
|
||||
}
|
||||
|
||||
/// Run the held passes over the given operation.
|
||||
LogicalResult run(Operation *op, AnalysisManager am);
|
||||
|
||||
|
@ -109,11 +100,6 @@ public:
|
|||
/// Merge the pass statistics of this class into 'other'.
|
||||
void mergeStatisticsInto(OpPassManager &other);
|
||||
|
||||
/// Register dependent dialects for the current pass manager.
|
||||
/// This is forwarding to every pass in this PassManager, see the
|
||||
/// documentation for the same method on the Pass class.
|
||||
void getDependentDialects(DialectRegistry &dialects) const;
|
||||
|
||||
private:
|
||||
OpPassManager(OperationName name, bool verifyPasses);
|
||||
|
||||
|
|
|
@ -21,14 +21,12 @@ class MemoryBuffer;
|
|||
} // end namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
class PassPipelineCLParser;
|
||||
|
||||
/// Perform the core processing behind `mlir-opt`:
|
||||
/// - outputStream is the stream where the resulting IR is printed.
|
||||
/// - buffer is the in-memory file to parser and process.
|
||||
/// - passPipeline is the specification of the pipeline that will be applied.
|
||||
/// - registry should contain all the dialects that can be parsed in the source.
|
||||
/// - splitInputFile will look for a "-----" marker in the input file, and load
|
||||
/// each chunk in an individual ModuleOp processed separately.
|
||||
/// - verifyDiagnostics enables a verification mode where comments starting with
|
||||
|
@ -37,25 +35,13 @@ class PassPipelineCLParser;
|
|||
/// - verifyPasses enables the IR verifier in-between each pass in the pipeline.
|
||||
/// - allowUnregisteredDialects allows to parse and create operation without
|
||||
/// registering the Dialect in the MLIRContext.
|
||||
/// - preloadDialectsInContext will trigger the upfront loading of all
|
||||
/// dialects from the global registry in the MLIRContext. This option is
|
||||
/// deprecated and will be removed soon.
|
||||
LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
|
||||
std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
const PassPipelineCLParser &passPipeline,
|
||||
DialectRegistry ®istry, bool splitInputFile,
|
||||
bool verifyDiagnostics, bool verifyPasses,
|
||||
bool allowUnregisteredDialects,
|
||||
bool preloadDialectsInContext = true);
|
||||
bool splitInputFile, bool verifyDiagnostics,
|
||||
bool verifyPasses, bool allowUnregisteredDialects);
|
||||
|
||||
/// Implementation for tools like `mlir-opt`.
|
||||
/// - toolName is used for the header displayed by `--help`.
|
||||
/// - registry should contain all the dialects that can be parsed in the source.
|
||||
/// - preloadDialectsInContext will trigger the upfront loading of all
|
||||
/// dialects from the global registry in the MLIRContext. This option is
|
||||
/// deprecated and will be removed soon.
|
||||
LogicalResult MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
|
||||
DialectRegistry ®istry,
|
||||
bool preloadDialectsInContext = true);
|
||||
LogicalResult MlirOptMain(int argc, char **argv, llvm::StringRef toolName);
|
||||
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace llvm {
|
||||
class Record;
|
||||
|
@ -26,7 +25,7 @@ namespace tblgen {
|
|||
// and provides helper methods for accessing them.
|
||||
class Dialect {
|
||||
public:
|
||||
explicit Dialect(const llvm::Record *def);
|
||||
explicit Dialect(const llvm::Record *def) : def(def) {}
|
||||
|
||||
// Returns the name of this dialect.
|
||||
StringRef getName() const;
|
||||
|
@ -44,10 +43,6 @@ public:
|
|||
// Returns the description of the dialect. Returns empty string if none.
|
||||
StringRef getDescription() const;
|
||||
|
||||
// Returns the list of dialect (class names) that this dialect depends on.
|
||||
// These are dialects that will be loaded on construction of this dialect.
|
||||
ArrayRef<StringRef> getDependentDialects() const;
|
||||
|
||||
// Returns the dialects extra class declaration code.
|
||||
llvm::Optional<StringRef> getExtraClassDeclaration() const;
|
||||
|
||||
|
@ -75,7 +70,6 @@ public:
|
|||
|
||||
private:
|
||||
const llvm::Record *def;
|
||||
std::vector<StringRef> dependentDialects;
|
||||
};
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -94,9 +94,6 @@ public:
|
|||
/// Return the C++ constructor call to create an instance of this pass.
|
||||
StringRef getConstructor() const;
|
||||
|
||||
/// Return the dialects this pass needs to be registered.
|
||||
ArrayRef<StringRef> getDependentDialects() const;
|
||||
|
||||
/// Return the options provided by this pass.
|
||||
ArrayRef<PassOption> getOptions() const;
|
||||
|
||||
|
@ -107,7 +104,6 @@ public:
|
|||
|
||||
private:
|
||||
const llvm::Record *def;
|
||||
std::vector<StringRef> dependentDialects;
|
||||
std::vector<PassOption> options;
|
||||
std::vector<PassStatistic> statistics;
|
||||
};
|
||||
|
|
|
@ -162,8 +162,6 @@ def BufferPlacement : FunctionPass<"buffer-placement"> {
|
|||
|
||||
}];
|
||||
let constructor = "mlir::createBufferPlacementPass()";
|
||||
// TODO: this pass likely shouldn't depend on Linalg?
|
||||
let dependentDialects = ["linalg::LinalgDialect"];
|
||||
}
|
||||
|
||||
def Canonicalizer : Pass<"canonicalize"> {
|
||||
|
|
|
@ -10,11 +10,9 @@
|
|||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
|
@ -52,17 +50,12 @@ private:
|
|||
/* ========================================================================== */
|
||||
|
||||
MlirContext mlirContextCreate() {
|
||||
auto *context = new MLIRContext(/*loadAllDialects=*/false);
|
||||
auto *context = new MLIRContext;
|
||||
return wrap(context);
|
||||
}
|
||||
|
||||
void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
|
||||
|
||||
void mlirContextLoadAllDialects(MlirContext context) {
|
||||
registerAllDialects(unwrap(context));
|
||||
getGlobalDialectRegistry().loadAll(unwrap(context));
|
||||
}
|
||||
|
||||
/* ========================================================================== */
|
||||
/* Location API. */
|
||||
/* ========================================================================== */
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Serialization.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
|
|
|
@ -12,43 +12,11 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
class AffineDialect;
|
||||
class StandardOpsDialect;
|
||||
|
||||
// Forward declaration from Dialect.h
|
||||
template <typename ConcreteDialect>
|
||||
void registerDialect(DialectRegistry ®istry);
|
||||
|
||||
namespace gpu {
|
||||
class GPUDialect;
|
||||
class GPUModuleOp;
|
||||
} // end namespace gpu
|
||||
|
||||
namespace LLVM {
|
||||
class LLVMDialect;
|
||||
class LLVMAVX512Dialect;
|
||||
} // end namespace LLVM
|
||||
|
||||
namespace NVVM {
|
||||
class NVVMDialect;
|
||||
} // end namespace NVVM
|
||||
|
||||
namespace ROCDL {
|
||||
class ROCDLDialect;
|
||||
} // end namespace ROCDL
|
||||
|
||||
namespace scf {
|
||||
class SCFDialect;
|
||||
} // end namespace scf
|
||||
|
||||
namespace spirv {
|
||||
class SPIRVDialect;
|
||||
} // end namespace spirv
|
||||
|
||||
namespace vector {
|
||||
class VectorDialect;
|
||||
} // end namespace vector
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
|
|||
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
|
||||
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
|
||||
const LowerToLLVMOptions &options)
|
||||
: llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
|
||||
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
|
||||
options(options) {
|
||||
assert(llvmDialect && "LLVM IR dialect is not registered");
|
||||
if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
|
||||
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
|
|
@ -12,16 +12,6 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
// Forward declaration from Dialect.h
|
||||
template <typename ConcreteDialect>
|
||||
void registerDialect(DialectRegistry ®istry);
|
||||
|
||||
namespace linalg {
|
||||
class LinalgDialect;
|
||||
} // end namespace linalg
|
||||
namespace vector {
|
||||
class VectorDialect;
|
||||
} // end namespace vector
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "mlir/Dialect/Affine/Passes.h.inc"
|
||||
|
|
|
@ -1244,7 +1244,6 @@ template <typename NamedStructuredOpType>
|
|||
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
|
||||
result.getContext()->getOrLoadDialect<StandardOpsDialect>();
|
||||
|
||||
// Optional attributes may be added.
|
||||
if (parser.parseOperandList(operandsInfo) ||
|
||||
|
|
|
@ -9,18 +9,9 @@
|
|||
#ifndef DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_
|
||||
#define DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
// Forward declaration from Dialect.h
|
||||
template <typename ConcreteDialect>
|
||||
void registerDialect(DialectRegistry ®istry);
|
||||
|
||||
namespace scf {
|
||||
class SCFDialect;
|
||||
} // end namespace scf
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "mlir/Dialect/Linalg/Passes.h.inc"
|
||||
|
|
|
@ -12,11 +12,6 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
// Forward declaration from Dialect.h
|
||||
template <typename ConcreteDialect>
|
||||
void registerDialect(DialectRegistry ®istry);
|
||||
|
||||
class AffineDialect;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "mlir/Dialect/SCF/Passes.h.inc"
|
||||
|
|
|
@ -517,7 +517,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
|
|||
|
||||
SDBMDialect *dialect;
|
||||
} converter;
|
||||
converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>();
|
||||
converter.dialect = affine.getContext()->getRegisteredDialect<SDBMDialect>();
|
||||
|
||||
if (auto result = converter.visit(affine))
|
||||
return result;
|
||||
|
|
|
@ -259,9 +259,7 @@ int mlir::JitRunnerMain(
|
|||
}
|
||||
}
|
||||
|
||||
MLIRContext context(/*loadAllDialects=*/false);
|
||||
registerAllDialects(&context);
|
||||
|
||||
MLIRContext context;
|
||||
auto m = parseMLIRInput(options.inputFilename, &context);
|
||||
if (!m) {
|
||||
llvm::errs() << "could not parse the input IR\n";
|
||||
|
|
|
@ -27,29 +27,21 @@ DialectAsmParser::~DialectAsmParser() {}
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Registry for all dialect allocation functions.
|
||||
static llvm::ManagedStatic<DialectRegistry> dialectRegistry;
|
||||
DialectRegistry &mlir::getGlobalDialectRegistry() { return *dialectRegistry; }
|
||||
static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>>
|
||||
dialectRegistry;
|
||||
|
||||
void Dialect::registerDialectAllocator(
|
||||
TypeID typeID, const DialectAllocatorFunction &function) {
|
||||
assert(function &&
|
||||
"Attempting to register an empty dialect initialize function");
|
||||
dialectRegistry->insert({typeID, function});
|
||||
}
|
||||
|
||||
/// Registers all dialects and hooks from the global registries with the
|
||||
/// specified MLIRContext.
|
||||
void mlir::registerAllDialects(MLIRContext *context) {
|
||||
dialectRegistry->appendTo(context->getDialectRegistry());
|
||||
}
|
||||
|
||||
Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
|
||||
auto it = registry.find(name.str());
|
||||
if (it == registry.end())
|
||||
return nullptr;
|
||||
return it->second.second(context);
|
||||
}
|
||||
|
||||
void DialectRegistry::insert(TypeID typeID, StringRef name,
|
||||
DialectAllocatorFunction ctor) {
|
||||
auto inserted =
|
||||
registry.insert(std::make_pair(name, std::make_pair(typeID, ctor)));
|
||||
if (!inserted.second && inserted.first->second.first != typeID) {
|
||||
llvm::report_fatal_error(
|
||||
"Trying to register different dialects for the same namespace: " +
|
||||
name);
|
||||
}
|
||||
for (const auto &it : *dialectRegistry)
|
||||
it.second(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -127,7 +119,7 @@ DialectInterface::~DialectInterface() {}
|
|||
|
||||
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
|
||||
MLIRContext *ctx, TypeID interfaceKind) {
|
||||
for (auto *dialect : ctx->getLoadedDialects()) {
|
||||
for (auto *dialect : ctx->getRegisteredDialects()) {
|
||||
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
|
||||
interfaces.insert(interface);
|
||||
orderedInterfaces.push_back(interface);
|
||||
|
|
|
@ -31,13 +31,10 @@
|
|||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/RWMutex.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <memory>
|
||||
|
||||
#define DEBUG_TYPE "mlircontext"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
|
@ -278,8 +275,7 @@ public:
|
|||
|
||||
/// This is a list of dialects that are created referring to this context.
|
||||
/// The MLIRContext owns the objects.
|
||||
DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
|
||||
DialectRegistry dialectsRegistry;
|
||||
std::vector<std::unique_ptr<Dialect>> dialects;
|
||||
|
||||
/// This is a mapping from operation name to AbstractOperation for registered
|
||||
/// operations.
|
||||
|
@ -350,7 +346,7 @@ public:
|
|||
};
|
||||
} // end namespace mlir
|
||||
|
||||
MLIRContext::MLIRContext(bool loadAllDialects) : impl(new MLIRContextImpl()) {
|
||||
MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
|
||||
// Initialize values based on the command line flags if they were provided.
|
||||
if (clOptions.isConstructed()) {
|
||||
disableMultithreading(clOptions->disableThreading);
|
||||
|
@ -359,9 +355,8 @@ MLIRContext::MLIRContext(bool loadAllDialects) : impl(new MLIRContextImpl()) {
|
|||
}
|
||||
|
||||
// Register dialects with this context.
|
||||
getOrLoadDialect<BuiltinDialect>();
|
||||
if (loadAllDialects)
|
||||
loadAllGloballyRegisteredDialects();
|
||||
getOrCreateDialect<BuiltinDialect>();
|
||||
registerAllDialects(this);
|
||||
|
||||
// Initialize several common attributes and types to avoid the need to lock
|
||||
// the context when accessing them.
|
||||
|
@ -443,72 +438,54 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
|
|||
// Dialect and Operation Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DialectRegistry &MLIRContext::getDialectRegistry() {
|
||||
return impl->dialectsRegistry;
|
||||
}
|
||||
|
||||
/// Return information about all registered IR dialects.
|
||||
std::vector<Dialect *> MLIRContext::getLoadedDialects() {
|
||||
std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
|
||||
std::vector<Dialect *> result;
|
||||
result.reserve(impl->loadedDialects.size());
|
||||
for (auto &dialect : impl->loadedDialects)
|
||||
result.push_back(dialect.second.get());
|
||||
llvm::array_pod_sort(result.begin(), result.end(),
|
||||
[](Dialect *const *lhs, Dialect *const *rhs) -> int {
|
||||
return (*lhs)->getNamespace() < (*rhs)->getNamespace();
|
||||
});
|
||||
return result;
|
||||
}
|
||||
std::vector<StringRef> MLIRContext::getAvailableDialects() {
|
||||
std::vector<StringRef> result;
|
||||
for (auto &dialect : impl->dialectsRegistry)
|
||||
result.push_back(dialect.first);
|
||||
result.reserve(impl->dialects.size());
|
||||
for (auto &dialect : impl->dialects)
|
||||
result.push_back(dialect.get());
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Get a registered IR dialect with the given namespace. If none is found,
|
||||
/// then return nullptr.
|
||||
Dialect *MLIRContext::getLoadedDialect(StringRef name) {
|
||||
Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
|
||||
// Dialects are sorted by name, so we can use binary search for lookup.
|
||||
auto it = impl->loadedDialects.find(name);
|
||||
return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
|
||||
}
|
||||
|
||||
Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
|
||||
Dialect *dialect = getLoadedDialect(name);
|
||||
if (dialect)
|
||||
return dialect;
|
||||
return impl->dialectsRegistry.loadByName(name, this);
|
||||
auto it = llvm::lower_bound(
|
||||
impl->dialects, name,
|
||||
[](const auto &lhs, StringRef rhs) { return lhs->getNamespace() < rhs; });
|
||||
return (it != impl->dialects.end() && (*it)->getNamespace() == name)
|
||||
? (*it).get()
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
/// Get a dialect for the provided namespace and TypeID: abort the program if a
|
||||
/// dialect exist for this namespace with different TypeID. Returns a pointer to
|
||||
/// the dialect owned by the context.
|
||||
Dialect *
|
||||
MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
|
||||
function_ref<std::unique_ptr<Dialect>()> ctor) {
|
||||
MLIRContext::getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
|
||||
function_ref<std::unique_ptr<Dialect>()> ctor) {
|
||||
auto &impl = getImpl();
|
||||
// Get the correct insertion position sorted by namespace.
|
||||
std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
|
||||
|
||||
if (!dialect) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Load new dialect in Context" << dialectNamespace);
|
||||
dialect = ctor();
|
||||
assert(dialect && "dialect ctor failed");
|
||||
return dialect.get();
|
||||
}
|
||||
auto insertPt =
|
||||
llvm::lower_bound(impl.dialects, nullptr,
|
||||
[&](const std::unique_ptr<Dialect> &lhs,
|
||||
const std::unique_ptr<Dialect> &rhs) {
|
||||
if (!lhs)
|
||||
return dialectNamespace < rhs->getNamespace();
|
||||
return lhs->getNamespace() < dialectNamespace;
|
||||
});
|
||||
|
||||
// Abort if dialect with namespace has already been registered.
|
||||
if (dialect->getTypeID() != dialectID)
|
||||
if (insertPt != impl.dialects.end() &&
|
||||
(*insertPt)->getNamespace() == dialectNamespace) {
|
||||
if ((*insertPt)->getTypeID() == dialectID)
|
||||
return insertPt->get();
|
||||
llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
|
||||
"' has already been registered");
|
||||
|
||||
return dialect.get();
|
||||
}
|
||||
|
||||
void MLIRContext::loadAllGloballyRegisteredDialects() {
|
||||
getGlobalDialectRegistry().loadAll(this);
|
||||
}
|
||||
auto it = impl.dialects.insert(insertPt, ctor());
|
||||
return &**it;
|
||||
}
|
||||
|
||||
bool MLIRContext::allowsUnregisteredDialects() {
|
||||
|
|
|
@ -214,7 +214,7 @@ Dialect *Operation::getDialect() {
|
|||
|
||||
// If this operation hasn't been registered or doesn't have abstract
|
||||
// operation, try looking up the dialect name in the context.
|
||||
return getContext()->getLoadedDialect(getName().getDialect());
|
||||
return getContext()->getRegisteredDialect(getName().getDialect());
|
||||
}
|
||||
|
||||
Region *Operation::getParentRegion() {
|
||||
|
|
|
@ -50,7 +50,7 @@ public:
|
|||
Dialect *getDialectForAttribute(const NamedAttribute &attr) {
|
||||
assert(attr.first.strref().contains('.') && "expected dialect attribute");
|
||||
auto dialectNamePair = attr.first.strref().split('.');
|
||||
return ctx->getLoadedDialect(dialectNamePair.first);
|
||||
return ctx->getRegisteredDialect(dialectNamePair.first);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -218,7 +218,7 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
|
|||
auto it = dialectAllowsUnknownOps.find(dialectPrefix);
|
||||
if (it == dialectAllowsUnknownOps.end()) {
|
||||
// If the operation dialect is registered, query it directly.
|
||||
if (auto *dialect = ctx->getLoadedDialect(dialectPrefix))
|
||||
if (auto *dialect = ctx->getRegisteredDialect(dialectPrefix))
|
||||
it = dialectAllowsUnknownOps
|
||||
.try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
|
||||
.first;
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
|
||||
#include "Parser.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
@ -247,11 +246,6 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
|
|||
return emitError("duplicate key in dictionary attribute");
|
||||
consumeToken();
|
||||
|
||||
// Lazy load a dialect in the context if there is a possible namespace.
|
||||
auto splitName = nameId->strref().split('.');
|
||||
if (!splitName.second.empty())
|
||||
getContext()->getOrLoadDialect(splitName.first);
|
||||
|
||||
// Try to parse the '=' for the attribute value.
|
||||
if (!consumeIf(Token::equal)) {
|
||||
// If there is no '=', we treat this as a unit attribute.
|
||||
|
@ -823,9 +817,7 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
|
|||
return (emitError("expected dialect namespace"), nullptr);
|
||||
|
||||
auto name = getToken().getStringValue();
|
||||
// Lazy load a dialect in the context if there is a possible namespace.
|
||||
Dialect *dialect = builder.getContext()->getOrLoadDialect(name);
|
||||
|
||||
auto *dialect = builder.getContext()->getRegisteredDialect(name);
|
||||
// TODO: Allow for having an unknown dialect on an opaque
|
||||
// attribute. Otherwise, it can't be roundtripped without having the dialect
|
||||
// registered.
|
||||
|
|
|
@ -526,8 +526,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
|
|||
return Attribute();
|
||||
|
||||
// If we found a registered dialect, then ask it to parse the attribute.
|
||||
if (Dialect *dialect =
|
||||
builder.getContext()->getOrLoadDialect(dialectName)) {
|
||||
if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
|
||||
return parseSymbol<Attribute>(
|
||||
symbolData, state.context, state.symbols, [&](Parser &parser) {
|
||||
CustomDialectAsmParser customParser(symbolData, parser);
|
||||
|
@ -564,9 +563,7 @@ Type Parser::parseExtendedType() {
|
|||
[&](StringRef dialectName, StringRef symbolData,
|
||||
llvm::SMLoc loc) -> Type {
|
||||
// If we found a registered dialect, then ask it to parse the type.
|
||||
auto *dialect = state.context->getOrLoadDialect(dialectName);
|
||||
|
||||
if (dialect) {
|
||||
if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
|
||||
return parseSymbol<Type>(
|
||||
symbolData, state.context, state.symbols, [&](Parser &parser) {
|
||||
CustomDialectAsmParser customParser(symbolData, parser);
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
|
||||
#include "Parser.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/Parser.h"
|
||||
|
@ -728,7 +727,7 @@ Operation *OperationParser::parseGenericOperation() {
|
|||
// Get location information for the operation.
|
||||
auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
|
||||
|
||||
std::string name = getToken().getStringValue();
|
||||
auto name = getToken().getStringValue();
|
||||
if (name.empty())
|
||||
return (emitError("empty operation name is invalid"), nullptr);
|
||||
if (name.find('\0') != StringRef::npos)
|
||||
|
@ -738,15 +737,6 @@ Operation *OperationParser::parseGenericOperation() {
|
|||
|
||||
OperationState result(srcLocation, name);
|
||||
|
||||
// Lazy load dialects in the context as needed.
|
||||
if (!result.name.getAbstractOperation()) {
|
||||
StringRef dialectName = StringRef(name).split('.').first;
|
||||
if (!getContext()->getLoadedDialect(dialectName) &&
|
||||
getContext()->getOrLoadDialect(dialectName)) {
|
||||
result.name = OperationName(name, getContext());
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the operand list.
|
||||
SmallVector<SSAUseInfo, 8> operandInfos;
|
||||
if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
|
||||
|
@ -1452,28 +1442,17 @@ private:
|
|||
|
||||
Operation *
|
||||
OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
|
||||
llvm::SMLoc opLoc = getToken().getLoc();
|
||||
StringRef opName = getTokenSpelling();
|
||||
auto opLoc = getToken().getLoc();
|
||||
auto opName = getTokenSpelling();
|
||||
|
||||
auto *opDefinition = AbstractOperation::lookup(opName, getContext());
|
||||
if (!opDefinition) {
|
||||
if (opName.contains('.')) {
|
||||
// This op has a dialect, we try to check if we can register it in the
|
||||
// context on the fly.
|
||||
StringRef dialectName = opName.split('.').first;
|
||||
if (!getContext()->getLoadedDialect(dialectName) &&
|
||||
getContext()->getOrLoadDialect(dialectName)) {
|
||||
opDefinition = AbstractOperation::lookup(opName, getContext());
|
||||
}
|
||||
} else {
|
||||
// If the operation name has no namespace prefix we treat it as a standard
|
||||
// operation and prefix it with "std".
|
||||
// TODO: Would it be better to just build a mapping of the registered
|
||||
// operations in the standard dialect?
|
||||
if (getContext()->getOrLoadDialect("std"))
|
||||
opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(),
|
||||
getContext());
|
||||
}
|
||||
if (!opDefinition && !opName.contains('.')) {
|
||||
// If the operation name has no namespace prefix we treat it as a standard
|
||||
// operation and prefix it with "std".
|
||||
// TODO: Would it be better to just build a mapping of the registered
|
||||
// operations in the standard dialect?
|
||||
opDefinition =
|
||||
AbstractOperation::lookup(Twine("std." + opName).str(), getContext());
|
||||
}
|
||||
|
||||
if (!opDefinition) {
|
||||
|
|
|
@ -290,13 +290,6 @@ OpPassManager::pass_iterator OpPassManager::begin() {
|
|||
}
|
||||
OpPassManager::pass_iterator OpPassManager::end() { return impl->passes.end(); }
|
||||
|
||||
OpPassManager::const_pass_iterator OpPassManager::begin() const {
|
||||
return impl->passes.begin();
|
||||
}
|
||||
OpPassManager::const_pass_iterator OpPassManager::end() const {
|
||||
return impl->passes.end();
|
||||
}
|
||||
|
||||
/// Run all of the passes in this manager over the current operation.
|
||||
LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) {
|
||||
// Run each of the held passes.
|
||||
|
@ -353,16 +346,6 @@ void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
|
|||
::printAsTextualPipeline(impl->passes, os);
|
||||
}
|
||||
|
||||
static void registerDialectsForPipeline(const OpPassManager &pm,
|
||||
DialectRegistry &dialects) {
|
||||
for (const Pass &pass : pm.getPasses())
|
||||
pass.getDependentDialects(dialects);
|
||||
}
|
||||
|
||||
void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
|
||||
registerDialectsForPipeline(*this, dialects);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpToOpPassAdaptor
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -395,11 +378,6 @@ OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
|
|||
mgrs.emplace_back(std::move(mgr));
|
||||
}
|
||||
|
||||
void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
|
||||
for (auto &pm : mgrs)
|
||||
pm.getDependentDialects(dialects);
|
||||
}
|
||||
|
||||
/// Merge the current pass adaptor into given 'rhs'.
|
||||
void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
|
||||
for (auto &pm : mgrs) {
|
||||
|
@ -743,11 +721,6 @@ LogicalResult PassManager::run(ModuleOp module) {
|
|||
// pipeline.
|
||||
getImpl().coalesceAdjacentAdaptorPasses();
|
||||
|
||||
// Register all dialects for the current pipeline.
|
||||
DialectRegistry dependentDialects;
|
||||
getDependentDialects(dependentDialects);
|
||||
dependentDialects.loadAll(module.getContext());
|
||||
|
||||
// Construct an analysis manager for the pipeline.
|
||||
ModuleAnalysisManager am(module, instrumentor.get());
|
||||
|
||||
|
|
|
@ -43,10 +43,6 @@ public:
|
|||
/// Returns the pass managers held by this adaptor.
|
||||
MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; }
|
||||
|
||||
/// Populate the set of dependent dialects for the passes in the current
|
||||
/// adaptor.
|
||||
void getDependentDialects(DialectRegistry &dialects) const override;
|
||||
|
||||
/// Return the async pass managers held by this parallel adaptor.
|
||||
MutableArrayRef<SmallVector<OpPassManager, 1>> getParallelPassManagers() {
|
||||
return asyncExecutors;
|
||||
|
|
|
@ -81,18 +81,13 @@ static LogicalResult processBuffer(raw_ostream &os,
|
|||
std::unique_ptr<MemoryBuffer> ownedBuffer,
|
||||
bool verifyDiagnostics, bool verifyPasses,
|
||||
bool allowUnregisteredDialects,
|
||||
bool preloadDialectsInContext,
|
||||
const PassPipelineCLParser &passPipeline,
|
||||
DialectRegistry ®istry) {
|
||||
const PassPipelineCLParser &passPipeline) {
|
||||
// Tell sourceMgr about this buffer, which is what the parser will pick up.
|
||||
SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
|
||||
|
||||
// Parse the input file.
|
||||
MLIRContext context(/*loadAllDialects=*/preloadDialectsInContext);
|
||||
registry.appendTo(context.getDialectRegistry());
|
||||
if (preloadDialectsInContext)
|
||||
registry.loadAll(&context);
|
||||
MLIRContext context;
|
||||
context.allowUnregisteredDialects(allowUnregisteredDialects);
|
||||
context.printOpOnDiagnostic(!verifyDiagnostics);
|
||||
|
||||
|
@ -120,10 +115,9 @@ static LogicalResult processBuffer(raw_ostream &os,
|
|||
LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
|
||||
std::unique_ptr<MemoryBuffer> buffer,
|
||||
const PassPipelineCLParser &passPipeline,
|
||||
DialectRegistry ®istry, bool splitInputFile,
|
||||
bool verifyDiagnostics, bool verifyPasses,
|
||||
bool allowUnregisteredDialects,
|
||||
bool preloadDialectsInContext) {
|
||||
bool splitInputFile, bool verifyDiagnostics,
|
||||
bool verifyPasses,
|
||||
bool allowUnregisteredDialects) {
|
||||
// The split-input-file mode is a very specific mode that slices the file
|
||||
// up into small pieces and checks each independently.
|
||||
if (splitInputFile)
|
||||
|
@ -132,19 +126,15 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
|
|||
[&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
|
||||
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
|
||||
verifyPasses, allowUnregisteredDialects,
|
||||
preloadDialectsInContext, passPipeline,
|
||||
registry);
|
||||
passPipeline);
|
||||
},
|
||||
outputStream);
|
||||
|
||||
return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
|
||||
verifyPasses, allowUnregisteredDialects,
|
||||
preloadDialectsInContext, passPipeline, registry);
|
||||
verifyPasses, allowUnregisteredDialects, passPipeline);
|
||||
}
|
||||
|
||||
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
|
||||
DialectRegistry ®istry,
|
||||
bool preloadDialectsInContext) {
|
||||
LogicalResult mlir::MlirOptMain(int argc, char **argv, StringRef toolName) {
|
||||
static cl::opt<std::string> inputFilename(
|
||||
cl::Positional, cl::desc("<input file>"), cl::init("-"));
|
||||
|
||||
|
@ -190,19 +180,25 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
|
|||
{
|
||||
llvm::raw_string_ostream os(helpHeader);
|
||||
MLIRContext context;
|
||||
interleaveComma(registry, os, [&](auto ®istryEntry) {
|
||||
StringRef name = registryEntry.first;
|
||||
os << name;
|
||||
interleaveComma(context.getRegisteredDialects(), os, [&](Dialect *dialect) {
|
||||
StringRef name = dialect->getNamespace();
|
||||
// filter the builtin dialect.
|
||||
if (name.empty())
|
||||
os << "<builtin>";
|
||||
else
|
||||
os << name;
|
||||
});
|
||||
}
|
||||
// Parse pass names in main to ensure static initialization completed.
|
||||
cl::ParseCommandLineOptions(argc, argv, helpHeader);
|
||||
|
||||
if (showDialects) {
|
||||
llvm::outs() << "Available Dialects:\n";
|
||||
llvm::outs() << "Registered Dialects:\n";
|
||||
MLIRContext context;
|
||||
interleave(
|
||||
registry, llvm::outs(),
|
||||
[](auto ®istryEntry) { llvm::outs() << registryEntry.first; }, "\n");
|
||||
context.getRegisteredDialects(), llvm::outs(),
|
||||
[](Dialect *dialect) { llvm::outs() << dialect->getNamespace(); },
|
||||
"\n");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -220,9 +216,9 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
|
|||
return failure();
|
||||
}
|
||||
|
||||
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
|
||||
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
|
||||
splitInputFile, verifyDiagnostics, verifyPasses,
|
||||
allowUnregisteredDialects, preloadDialectsInContext)))
|
||||
allowUnregisteredDialects)))
|
||||
return failure();
|
||||
|
||||
// Keep the output file if the invocation of MlirOptMain was successful.
|
||||
|
|
|
@ -15,10 +15,6 @@
|
|||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
Dialect::Dialect(const llvm::Record *def) : def(def) {
|
||||
for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
|
||||
dependentDialects.push_back(dialect);
|
||||
}
|
||||
|
||||
StringRef Dialect::getName() const { return def->getValueAsString("name"); }
|
||||
|
||||
|
@ -50,10 +46,6 @@ StringRef Dialect::getDescription() const {
|
|||
return getAsStringOrEmpty(*def, "description");
|
||||
}
|
||||
|
||||
ArrayRef<StringRef> Dialect::getDependentDialects() const {
|
||||
return dependentDialects;
|
||||
}
|
||||
|
||||
llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
|
||||
auto value = def->getValueAsString("extraClassDeclaration");
|
||||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||
|
|
|
@ -69,8 +69,6 @@ Pass::Pass(const llvm::Record *def) : def(def) {
|
|||
options.push_back(PassOption(init));
|
||||
for (auto *init : def->getValueAsListOfDefs("statistics"))
|
||||
statistics.push_back(PassStatistic(init));
|
||||
for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
|
||||
dependentDialects.push_back(dialect);
|
||||
}
|
||||
|
||||
StringRef Pass::getArgument() const {
|
||||
|
@ -90,9 +88,6 @@ StringRef Pass::getDescription() const {
|
|||
StringRef Pass::getConstructor() const {
|
||||
return def->getValueAsString("constructor");
|
||||
}
|
||||
ArrayRef<StringRef> Pass::getDependentDialects() const {
|
||||
return dependentDialects;
|
||||
}
|
||||
|
||||
ArrayRef<PassOption> Pass::getOptions() const { return options; }
|
||||
|
||||
|
|
|
@ -836,7 +836,6 @@ LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
|
|||
OwningModuleRef
|
||||
mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
|
||||
MLIRContext *context) {
|
||||
context->loadDialect<LLVMDialect>();
|
||||
OwningModuleRef module(ModuleOp::create(
|
||||
FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));
|
||||
|
||||
|
|
|
@ -302,7 +302,8 @@ ModuleTranslation::ModuleTranslation(Operation *module,
|
|||
: mlirModule(module), llvmModule(std::move(llvmModule)),
|
||||
debugTranslation(
|
||||
std::make_unique<DebugTranslation>(module, *this->llvmModule)),
|
||||
ompDialect(module->getContext()->getOrLoadDialect<omp::OpenMPDialect>()),
|
||||
ompDialect(
|
||||
module->getContext()->getRegisteredDialect<omp::OpenMPDialect>()),
|
||||
typeTranslator(this->llvmModule->getContext()) {
|
||||
assert(satisfiesLLVMModule(mlirModule) &&
|
||||
"mlirModule should honor LLVM's module semantics.");
|
||||
|
@ -943,8 +944,8 @@ ModuleTranslation::lookupValues(ValueRange values) {
|
|||
|
||||
std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(
|
||||
Operation *m, llvm::LLVMContext &llvmContext, StringRef name) {
|
||||
m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
|
||||
auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
|
||||
|
||||
if (auto dataLayoutAttr =
|
||||
m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()))
|
||||
llvmModule->setDataLayout(dataLayoutAttr.cast<StringAttr>().getValue());
|
||||
|
|
|
@ -12,13 +12,6 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
// Forward declaration from Dialect.h
|
||||
template <typename ConcreteDialect>
|
||||
void registerDialect(DialectRegistry ®istry);
|
||||
|
||||
namespace linalg {
|
||||
class LinalgDialect;
|
||||
} // end namespace linalg
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "mlir/Transforms/Passes.h.inc"
|
||||
|
|
|
@ -383,7 +383,6 @@ static int printStandardTypes(MlirContext ctx) {
|
|||
int main() {
|
||||
mlirRegisterAllDialects();
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
mlirContextLoadAllDialects(ctx);
|
||||
MlirLocation location = mlirLocationUnknownGet(ctx);
|
||||
|
||||
MlirModule moduleOp = makeAdd(ctx, location);
|
||||
|
|
|
@ -36,18 +36,16 @@ using namespace mlir::edsc;
|
|||
using namespace mlir::edsc::intrinsics;
|
||||
|
||||
static MLIRContext &globalContext() {
|
||||
static thread_local MLIRContext context(/*loadAllDialects=*/false);
|
||||
static thread_local bool initOnce = [&]() {
|
||||
// clang-format off
|
||||
context.loadDialect<AffineDialect,
|
||||
scf::SCFDialect,
|
||||
linalg::LinalgDialect,
|
||||
StandardOpsDialect,
|
||||
vector::VectorDialect>();
|
||||
// clang-format on
|
||||
static bool init_once = []() {
|
||||
registerDialect<AffineDialect>();
|
||||
registerDialect<linalg::LinalgDialect>();
|
||||
registerDialect<scf::SCFDialect>();
|
||||
registerDialect<StandardOpsDialect>();
|
||||
registerDialect<vector::VectorDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)initOnce;
|
||||
(void)init_once;
|
||||
static thread_local MLIRContext context;
|
||||
context.allowUnregisteredDialects();
|
||||
return context;
|
||||
}
|
||||
|
|
|
@ -19,19 +19,18 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
// Load the SDBM dialect
|
||||
static DialectRegistration<SDBMDialect> SDBMRegistration;
|
||||
|
||||
static MLIRContext *ctx() {
|
||||
static thread_local MLIRContext context(/*loadAllDialects=*/false);
|
||||
static thread_local bool once =
|
||||
(context.getOrLoadDialect<SDBMDialect>(), true);
|
||||
(void)once;
|
||||
static thread_local MLIRContext context;
|
||||
return &context;
|
||||
}
|
||||
|
||||
static SDBMDialect *dialect() {
|
||||
static thread_local SDBMDialect *d = nullptr;
|
||||
if (!d) {
|
||||
d = ctx()->getOrLoadDialect<SDBMDialect>();
|
||||
d = ctx()->getRegisteredDialect<SDBMDialect>();
|
||||
}
|
||||
return d;
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
|
@ -73,9 +72,6 @@ struct VectorizerTestPass
|
|||
: public PassWrapper<VectorizerTestPass, FunctionPass> {
|
||||
static constexpr auto kTestAffineMapOpName = "test_affine_map";
|
||||
static constexpr auto kTestAffineMapAttrName = "affine_map";
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<vector::VectorDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override;
|
||||
void testVectorShapeRatio(llvm::raw_ostream &outs);
|
||||
|
|
|
@ -30,7 +30,7 @@ void PrintOpAvailability::runOnFunction() {
|
|||
auto f = getFunction();
|
||||
llvm::outs() << f.getName() << "\n";
|
||||
|
||||
Dialect *spvDialect = getContext().getLoadedDialect("spv");
|
||||
Dialect *spvDialect = getContext().getRegisteredDialect("spv");
|
||||
|
||||
f.getOperation()->walk([&](Operation *op) {
|
||||
if (op->getDialect() != spvDialect)
|
||||
|
|
|
@ -21,10 +21,6 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
void mlir::registerTestDialect(DialectRegistry ®istry) {
|
||||
registry.insert<TestDialect>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestDialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -37,8 +37,6 @@ namespace mlir {
|
|||
#define GET_OP_CLASSES
|
||||
#include "TestOps.h.inc"
|
||||
|
||||
void registerTestDialect(DialectRegistry ®istry);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TESTDIALECT_H
|
||||
|
|
|
@ -768,10 +768,6 @@ struct TestTypeConversionProducer
|
|||
|
||||
struct TestTypeConversionDriver
|
||||
: public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<TestDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
// Initialize the type converter.
|
||||
TypeConverter converter;
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/GPU/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
|
@ -20,9 +19,6 @@ using namespace mlir;
|
|||
namespace {
|
||||
struct TestAllReduceLoweringPass
|
||||
: public PassWrapper<TestAllReduceLoweringPass, OperationPass<ModuleOp>> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<StandardOpsDialect>();
|
||||
}
|
||||
void runOnOperation() override {
|
||||
OwningRewritePatternList patterns;
|
||||
populateGpuRewritePatterns(&getContext(), patterns);
|
||||
|
|
|
@ -116,10 +116,6 @@ struct TestBufferPlacementPreparationPass
|
|||
patterns->insert<GenericOpConverter>(context, placer, converter);
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<linalg::LinalgDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext &context = this->getContext();
|
||||
ConversionTarget target(context);
|
||||
|
|
|
@ -13,9 +13,6 @@
|
|||
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/GPU/MemoryPromotion.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
|
@ -29,10 +26,6 @@ namespace {
|
|||
class TestGpuMemoryPromotionPass
|
||||
: public PassWrapper<TestGpuMemoryPromotionPass,
|
||||
OperationPass<gpu::GPUFuncOp>> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<StandardOpsDialect, scf::SCFDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
gpu::GPUFuncOp op = getOperation();
|
||||
for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
@ -23,9 +22,6 @@ struct TestLinalgHoisting
|
|||
: public PassWrapper<TestLinalgHoisting, FunctionPass> {
|
||||
TestLinalgHoisting() = default;
|
||||
TestLinalgHoisting(const TestLinalgHoisting &pass) {}
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
@ -31,16 +30,6 @@ struct TestLinalgTransforms
|
|||
TestLinalgTransforms() = default;
|
||||
TestLinalgTransforms(const TestLinalgTransforms &pass) {}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
// clang-format off
|
||||
registry.insert<AffineDialect,
|
||||
scf::SCFDialect,
|
||||
StandardOpsDialect,
|
||||
vector::VectorDialect,
|
||||
gpu::GPUDialect>();
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void runOnFunction() override;
|
||||
|
||||
Option<bool> testPatterns{*this, "test-patterns",
|
||||
|
|
|
@ -8,9 +8,6 @@
|
|||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
|
@ -131,11 +128,6 @@ struct TestVectorTransferFullPartialSplitPatterns
|
|||
TestVectorTransferFullPartialSplitPatterns() = default;
|
||||
TestVectorTransferFullPartialSplitPatterns(
|
||||
const TestVectorTransferFullPartialSplitPatterns &pass) {}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
|
||||
}
|
||||
|
||||
Option<bool> useLinalgOps{
|
||||
*this, "use-linalg-copy",
|
||||
llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// RUN: mlir-opt --show-dialects | FileCheck %s
|
||||
// CHECK: Available Dialects:
|
||||
// CHECK: Registered Dialects:
|
||||
// CHECK: affine
|
||||
// CHECK: gpu
|
||||
// CHECK: linalg
|
||||
|
|
|
@ -1703,7 +1703,7 @@ int main(int argc, char **argv) {
|
|||
if (testEmitIncludeTdHeader)
|
||||
output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
|
||||
|
||||
MLIRContext context(/*loadAllDialects=*/false);
|
||||
MLIRContext context;
|
||||
llvm::SourceMgr mgr;
|
||||
mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
|
||||
Parser parser(mgr, &context);
|
||||
|
|
|
@ -48,7 +48,6 @@ void registerTestConstantFold();
|
|||
void registerTestConvertGPUKernelToCubinPass();
|
||||
void registerTestConvertGPUKernelToHsacoPass();
|
||||
void registerTestDominancePass();
|
||||
void registerTestDialect(DialectRegistry &);
|
||||
void registerTestExpandTanhPass();
|
||||
void registerTestFunc();
|
||||
void registerTestGpuMemoryPromotionPass();
|
||||
|
@ -131,10 +130,5 @@ int main(int argc, char **argv) {
|
|||
#ifdef MLIR_INCLUDE_TESTS
|
||||
registerTestPasses();
|
||||
#endif
|
||||
DialectRegistry registry;
|
||||
registerAllDialects(registry);
|
||||
registerTestDialect(registry);
|
||||
return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver\n",
|
||||
registry,
|
||||
/*preloadDialectsInContext=*/false));
|
||||
return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver"));
|
||||
}
|
||||
|
|
|
@ -61,14 +61,11 @@ filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
|
|||
///
|
||||
/// {0}: The name of the dialect class.
|
||||
/// {1}: The dialect namespace.
|
||||
/// {2}: initialization code that is emitted in the ctor body before calling
|
||||
/// initialize()
|
||||
static const char *const dialectDeclBeginStr = R"(
|
||||
class {0} : public ::mlir::Dialect {
|
||||
explicit {0}(::mlir::MLIRContext *context)
|
||||
: ::mlir::Dialect(getDialectNamespace(), context,
|
||||
::mlir::TypeID::get<{0}>()) {{
|
||||
{2}
|
||||
initialize();
|
||||
}
|
||||
void initialize();
|
||||
|
@ -77,12 +74,6 @@ public:
|
|||
static ::llvm::StringRef getDialectNamespace() { return "{1}"; }
|
||||
)";
|
||||
|
||||
/// Registration for a single dependent dialect: to be inserted in the ctor
|
||||
/// above for each dependent dialect.
|
||||
const char *const dialectRegistrationTemplate = R"(
|
||||
getContext()->getOrLoadDialect<{0}>();
|
||||
)";
|
||||
|
||||
/// The code block for the attribute parser/printer hooks.
|
||||
static const char *const attrParserDecl = R"(
|
||||
/// Parse an attribute registered to this dialect.
|
||||
|
@ -145,18 +136,9 @@ static void emitDialectDecl(Dialect &dialect,
|
|||
iterator_range<DialectFilterIterator> dialectAttrs,
|
||||
iterator_range<DialectFilterIterator> dialectTypes,
|
||||
raw_ostream &os) {
|
||||
/// Build the list of dependent dialects
|
||||
std::string dependentDialectRegistrations;
|
||||
{
|
||||
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
|
||||
for (StringRef dependentDialect : dialect.getDependentDialects())
|
||||
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
|
||||
dependentDialect);
|
||||
}
|
||||
// Emit the start of the decl.
|
||||
std::string cppName = dialect.getCppClassName();
|
||||
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
|
||||
dependentDialectRegistrations);
|
||||
os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName());
|
||||
|
||||
// Check for any attributes/types registered to this dialect. If there are,
|
||||
// add the hooks for parsing/printing.
|
||||
|
|
|
@ -36,7 +36,6 @@ static llvm::cl::opt<std::string>
|
|||
/// {0}: The def name of the pass record.
|
||||
/// {1}: The base class for the pass.
|
||||
/// {2): The command line argument for the pass.
|
||||
/// {3}: The dependent dialects registration.
|
||||
const char *const passDeclBegin = R"(
|
||||
//===----------------------------------------------------------------------===//
|
||||
// {0}
|
||||
|
@ -64,20 +63,9 @@ public:
|
|||
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
|
||||
}
|
||||
|
||||
/// Return the dialect that must be loaded in the context before this pass.
|
||||
void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
|
||||
{3}
|
||||
}
|
||||
|
||||
protected:
|
||||
)";
|
||||
|
||||
/// Registration for a single dependent dialect, to be inserted for each
|
||||
/// dependent dialect in the `getDependentDialects` above.
|
||||
const char *const dialectRegistrationTemplate = R"(
|
||||
registry.insert<{0}>();
|
||||
)";
|
||||
|
||||
/// Emit the declarations for each of the pass options.
|
||||
static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
|
||||
for (const PassOption &opt : pass.getOptions()) {
|
||||
|
@ -106,15 +94,8 @@ static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
|
|||
|
||||
static void emitPassDecl(const Pass &pass, raw_ostream &os) {
|
||||
StringRef defName = pass.getDef()->getName();
|
||||
std::string dependentDialectRegistrations;
|
||||
{
|
||||
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
|
||||
for (StringRef dependentDialect : pass.getDependentDialects())
|
||||
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
|
||||
dependentDialect);
|
||||
}
|
||||
os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
|
||||
pass.getArgument(), dependentDialectRegistrations);
|
||||
pass.getArgument());
|
||||
emitPassOptionDecls(pass, os);
|
||||
emitPassStatisticDecls(pass, os);
|
||||
os << "};\n";
|
||||
|
|
|
@ -88,8 +88,7 @@ int main(int argc, char **argv) {
|
|||
// Processes the memory buffer with a new MLIRContext.
|
||||
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
|
||||
raw_ostream &os) {
|
||||
MLIRContext context(false);
|
||||
registerAllDialects(&context);
|
||||
MLIRContext context;
|
||||
context.allowUnregisteredDialects();
|
||||
context.printOpOnDiagnostic(!verifyDiagnostics);
|
||||
llvm::SourceMgr sourceMgr;
|
||||
|
|
|
@ -17,6 +17,9 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::quant;
|
||||
|
||||
// Load the quant dialect
|
||||
static DialectRegistration<QuantizationDialect> QuantOpsRegistration;
|
||||
|
||||
namespace {
|
||||
|
||||
// Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
|
||||
|
@ -75,8 +78,7 @@ UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
|
|||
}
|
||||
|
||||
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
|
||||
MLIRContext ctx(/*loadAllDialects=*/false);
|
||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||
MLIRContext ctx;
|
||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||
|
@ -93,8 +95,7 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
|
|||
}
|
||||
|
||||
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
|
||||
MLIRContext ctx(/*loadAllDialects=*/false);
|
||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||
MLIRContext ctx;
|
||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||
|
@ -118,8 +119,7 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
|
|||
}
|
||||
|
||||
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
|
||||
MLIRContext ctx(/*loadAllDialects=*/false);
|
||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||
MLIRContext ctx;
|
||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||
|
@ -143,8 +143,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
|
|||
}
|
||||
|
||||
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
|
||||
MLIRContext ctx(/*loadAllDialects=*/false);
|
||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||
MLIRContext ctx;
|
||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||
|
|
|
@ -38,8 +38,7 @@ using ::testing::StrEq;
|
|||
/// diagnostic checking utilities.
|
||||
class DeserializationTest : public ::testing::Test {
|
||||
protected:
|
||||
DeserializationTest() : context(/*loadAllDialects=*/false) {
|
||||
context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
|
||||
DeserializationTest() {
|
||||
// Register a diagnostic handler to capture the diagnostic so that we can
|
||||
// check it later.
|
||||
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
|
||||
|
|
|
@ -36,10 +36,7 @@ using namespace mlir;
|
|||
|
||||
class SerializationTest : public ::testing::Test {
|
||||
protected:
|
||||
SerializationTest() : context(/*loadAllDialects=*/false) {
|
||||
context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
|
||||
createModuleOp();
|
||||
}
|
||||
SerializationTest() { createModuleOp(); }
|
||||
|
||||
void createModuleOp() {
|
||||
OpBuilder builder(&context);
|
||||
|
|
|
@ -32,7 +32,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
|
|||
|
||||
namespace {
|
||||
TEST(DenseSplatTest, BoolSplat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
IntegerType boolTy = IntegerType::get(1, &context);
|
||||
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
|
||||
|
||||
|
@ -57,7 +57,7 @@ TEST(DenseSplatTest, BoolSplat) {
|
|||
TEST(DenseSplatTest, LargeBoolSplat) {
|
||||
constexpr int64_t boolCount = 56;
|
||||
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
IntegerType boolTy = IntegerType::get(1, &context);
|
||||
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
|
||||
|
||||
|
@ -80,7 +80,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
|
|||
}
|
||||
|
||||
TEST(DenseSplatTest, BoolNonSplat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
IntegerType boolTy = IntegerType::get(1, &context);
|
||||
RankedTensorType shape = RankedTensorType::get({6}, boolTy);
|
||||
|
||||
|
@ -92,7 +92,7 @@ TEST(DenseSplatTest, BoolNonSplat) {
|
|||
|
||||
TEST(DenseSplatTest, OddIntSplat) {
|
||||
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
constexpr size_t intWidth = 19;
|
||||
IntegerType intTy = IntegerType::get(intWidth, &context);
|
||||
APInt value(intWidth, 10);
|
||||
|
@ -101,7 +101,7 @@ TEST(DenseSplatTest, OddIntSplat) {
|
|||
}
|
||||
|
||||
TEST(DenseSplatTest, Int32Splat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
IntegerType intTy = IntegerType::get(32, &context);
|
||||
int value = 64;
|
||||
|
||||
|
@ -109,7 +109,7 @@ TEST(DenseSplatTest, Int32Splat) {
|
|||
}
|
||||
|
||||
TEST(DenseSplatTest, IntAttrSplat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
IntegerType intTy = IntegerType::get(85, &context);
|
||||
Attribute value = IntegerAttr::get(intTy, 109);
|
||||
|
||||
|
@ -117,7 +117,7 @@ TEST(DenseSplatTest, IntAttrSplat) {
|
|||
}
|
||||
|
||||
TEST(DenseSplatTest, F32Splat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
FloatType floatTy = FloatType::getF32(&context);
|
||||
float value = 10.0;
|
||||
|
||||
|
@ -125,7 +125,7 @@ TEST(DenseSplatTest, F32Splat) {
|
|||
}
|
||||
|
||||
TEST(DenseSplatTest, F64Splat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
FloatType floatTy = FloatType::getF64(&context);
|
||||
double value = 10.0;
|
||||
|
||||
|
@ -133,7 +133,7 @@ TEST(DenseSplatTest, F64Splat) {
|
|||
}
|
||||
|
||||
TEST(DenseSplatTest, FloatAttrSplat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
FloatType floatTy = FloatType::getF32(&context);
|
||||
Attribute value = FloatAttr::get(floatTy, 10.0);
|
||||
|
||||
|
@ -141,7 +141,7 @@ TEST(DenseSplatTest, FloatAttrSplat) {
|
|||
}
|
||||
|
||||
TEST(DenseSplatTest, BF16Splat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
FloatType floatTy = FloatType::getBF16(&context);
|
||||
Attribute value = FloatAttr::get(floatTy, 10.0);
|
||||
|
||||
|
@ -149,7 +149,7 @@ TEST(DenseSplatTest, BF16Splat) {
|
|||
}
|
||||
|
||||
TEST(DenseSplatTest, StringSplat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
Type stringType =
|
||||
OpaqueType::get(Identifier::get("test", &context), "string", &context);
|
||||
StringRef value = "test-string";
|
||||
|
@ -157,7 +157,7 @@ TEST(DenseSplatTest, StringSplat) {
|
|||
}
|
||||
|
||||
TEST(DenseSplatTest, StringAttrSplat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
Type stringType =
|
||||
OpaqueType::get(Identifier::get("test", &context), "string", &context);
|
||||
Attribute stringAttr = StringAttr::get("test-string", stringType);
|
||||
|
@ -165,28 +165,28 @@ TEST(DenseSplatTest, StringAttrSplat) {
|
|||
}
|
||||
|
||||
TEST(DenseComplexTest, ComplexFloatSplat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
|
||||
std::complex<float> value(10.0, 15.0);
|
||||
testSplat(complexType, value);
|
||||
}
|
||||
|
||||
TEST(DenseComplexTest, ComplexIntSplat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
|
||||
std::complex<int64_t> value(10, 15);
|
||||
testSplat(complexType, value);
|
||||
}
|
||||
|
||||
TEST(DenseComplexTest, ComplexAPFloatSplat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
|
||||
std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
|
||||
testSplat(complexType, value);
|
||||
}
|
||||
|
||||
TEST(DenseComplexTest, ComplexAPIntSplat) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
|
||||
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
|
||||
testSplat(complexType, value);
|
||||
|
|
|
@ -26,12 +26,12 @@ struct AnotherTestDialect : public Dialect {
|
|||
};
|
||||
|
||||
TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
|
||||
// Registering a dialect with the same namespace twice should result in a
|
||||
// failure.
|
||||
context.loadDialect<TestDialect>();
|
||||
ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
|
||||
context.getOrCreateDialect<TestDialect>();
|
||||
ASSERT_DEATH(context.getOrCreateDialect<AnotherTestDialect>(), "");
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
|
|
|
@ -25,7 +25,7 @@ static Operation *createOp(MLIRContext *context,
|
|||
|
||||
namespace {
|
||||
TEST(OperandStorageTest, NonResizable) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
||||
Operation *useOp =
|
||||
|
@ -49,7 +49,7 @@ TEST(OperandStorageTest, NonResizable) {
|
|||
}
|
||||
|
||||
TEST(OperandStorageTest, Resizable) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
||||
Operation *useOp =
|
||||
|
@ -77,7 +77,7 @@ TEST(OperandStorageTest, Resizable) {
|
|||
}
|
||||
|
||||
TEST(OperandStorageTest, RangeReplace) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
||||
Operation *useOp =
|
||||
|
@ -113,7 +113,7 @@ TEST(OperandStorageTest, RangeReplace) {
|
|||
}
|
||||
|
||||
TEST(OperandStorageTest, MutableRange) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
||||
Operation *useOp =
|
||||
|
|
|
@ -29,7 +29,7 @@ struct OpSpecificAnalysis {
|
|||
};
|
||||
|
||||
TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
|
||||
// Test fine grain invalidation of the module analysis manager.
|
||||
OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
|
||||
|
@ -50,7 +50,7 @@ TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
|
|||
}
|
||||
|
||||
TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
||||
// Create a function and a module.
|
||||
|
@ -79,7 +79,7 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
|
|||
}
|
||||
|
||||
TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
||||
// Create a function and a module.
|
||||
|
@ -122,7 +122,7 @@ struct CustomInvalidatingAnalysis {
|
|||
};
|
||||
|
||||
TEST(AnalysisManagerTest, CustomInvalidation) {
|
||||
MLIRContext context(false);
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
||||
// Create a function and a module.
|
||||
|
|
|
@ -17,17 +17,18 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
/// Load the SDBM dialect.
|
||||
static DialectRegistration<SDBMDialect> SDBMRegistration;
|
||||
|
||||
static MLIRContext *ctx() {
|
||||
static thread_local MLIRContext context(false);
|
||||
context.getOrLoadDialect<SDBMDialect>();
|
||||
static thread_local MLIRContext context;
|
||||
return &context;
|
||||
}
|
||||
|
||||
static SDBMDialect *dialect() {
|
||||
static thread_local SDBMDialect *d = nullptr;
|
||||
if (!d) {
|
||||
d = ctx()->getOrLoadDialect<SDBMDialect>();
|
||||
d = ctx()->getRegisteredDialect<SDBMDialect>();
|
||||
}
|
||||
return d;
|
||||
}
|
||||
|
|
|
@ -25,16 +25,11 @@ namespace mlir {
|
|||
// Test Fixture
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static MLIRContext &getContext() {
|
||||
static MLIRContext ctx(false);
|
||||
ctx.getOrLoadDialect<TestDialect>();
|
||||
return ctx;
|
||||
}
|
||||
/// Test fixture for providing basic utilities for testing.
|
||||
class OpBuildGenTest : public ::testing::Test {
|
||||
protected:
|
||||
OpBuildGenTest()
|
||||
: ctx(getContext()), builder(&ctx), loc(builder.getUnknownLoc()),
|
||||
: ctx{}, builder(&ctx), loc(builder.getUnknownLoc()),
|
||||
i32Ty(builder.getI32Type()), f32Ty(builder.getF32Type()),
|
||||
cstI32(builder.create<TableGenConstant>(loc, i32Ty)),
|
||||
cstF32(builder.create<TableGenConstant>(loc, f32Ty)),
|
||||
|
@ -91,7 +86,7 @@ protected:
|
|||
}
|
||||
|
||||
protected:
|
||||
MLIRContext &ctx;
|
||||
MLIRContext ctx;
|
||||
OpBuilder builder;
|
||||
Location loc;
|
||||
Type i32Ty;
|
||||
|
|
|
@ -42,7 +42,7 @@ static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
|
|||
/// Validates that test::TestStruct::classof correctly identifies a valid
|
||||
/// test::TestStruct.
|
||||
TEST(StructsGenTest, ClassofTrue) {
|
||||
mlir::MLIRContext context(false);
|
||||
mlir::MLIRContext context;
|
||||
auto structAttr = getTestStruct(&context);
|
||||
ASSERT_TRUE(test::TestStruct::classof(structAttr));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue