!17932 fix bcewithlogitsloss op error in pynative

From: @chujinjin
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
This commit is contained in:
mindspore-ci-bot 2021-06-08 09:12:38 +08:00 committed by Gitee
commit 1475a3e121
2 changed files with 13 additions and 0 deletions

View File

@ -189,6 +189,16 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
kernel_graph->SetExecOrderByDefault();
}
void GPUSession::RunOpOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}
void GPUSession::RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
@ -558,6 +568,7 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
// Prepare the graph
auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(kernel_graph);
RunOpOptimize(kernel_graph);
SelectKernel(kernel_graph);
RunOpHardwareOptimize(kernel_graph);
StartKernelRT();

View File

@ -66,6 +66,8 @@ class GPUSession : public SessionBasic {
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void RunOpOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);