!35444 print_reduce pass in dynamic_shape gpu

Merge pull request !35444 from TuDouNi/pass_support_dynamic
This commit is contained in:
i-robot 2022-06-07 05:50:03 +00:00 committed by Gitee
commit 057ed3b773
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 23 additions and 22 deletions

View File

@ -268,34 +268,35 @@ void GPUDeviceContext::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph)
void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
// In the dynamic shape scene, the infershape stage needs to call the primitive infer function.
// When the fusion operator generates a new primitive, but there
// is no corresponding primitive infer function, an error will occur.
// Therefore, this kind of scene does not support dynamic shape.
if (graph->is_dynamic_shape()) {
MS_LOG(INFO) << "Dynamic shape skip fusion";
return;
MS_LOG(INFO) << "Dynamic shape skip some fusion pass";
pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce"));
} else {
pm->AddPass(std::make_shared<opt::MatMulBiasAddFusion>());
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
pm->AddPass(std::make_shared<opt::AdamFusion>());
pm->AddPass(std::make_shared<opt::AllToAllFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
if (!graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
}
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce"));
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
pm->AddPass(std::make_shared<opt::InsertCastGPU>("insert_cast_gpu"));
pm->AddPass(std::make_shared<opt::NeighborExchangeV2Fusion>());
pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradFusion>());
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::MatMulBiasAddFusion>());
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
pm->AddPass(std::make_shared<opt::AdamFusion>());
pm->AddPass(std::make_shared<opt::AllToAllFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
if (!graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
}
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
pm->AddPass(std::make_shared<opt::PrintReduceFusion>("print_reduce"));
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
pm->AddPass(std::make_shared<opt::InsertCastGPU>("insert_cast_gpu"));
pm->AddPass(std::make_shared<opt::NeighborExchangeV2Fusion>());
pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradFusion>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(graph);
graph->SetExecOrderByDefault();