From 90feb6a6d22b89be0941a68f6e767e1846866f3e Mon Sep 17 00:00:00 2001 From: chujinjin Date: Mon, 7 Jun 2021 19:38:36 +0800 Subject: [PATCH] fix bcewithlogitsloss op error in pynative --- mindspore/ccsrc/backend/session/gpu_session.cc | 11 +++++++++++ mindspore/ccsrc/backend/session/gpu_session.h | 2 ++ 2 files changed, 13 insertions(+) diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 8047565bef4..9d002fb64c5 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -189,6 +189,16 @@ void GPUSession::HardwareOptimize(const std::shared_ptr &kernel_gra kernel_graph->SetExecOrderByDefault(); } +void GPUSession::RunOpOptimize(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + void GPUSession::RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) { auto optimizer = std::make_shared(); auto pm = std::make_shared(); @@ -558,6 +568,7 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap // Prepare the graph auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); MS_EXCEPTION_IF_NULL(kernel_graph); + RunOpOptimize(kernel_graph); SelectKernel(kernel_graph); RunOpHardwareOptimize(kernel_graph); StartKernelRT(); diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h index 3ce2b34eca3..d6dad45db2a 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.h +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -66,6 +66,8 @@ class GPUSession : public SessionBasic { void HardwareOptimize(const std::shared_ptr &kernel_graph); + void RunOpOptimize(const std::shared_ptr &kernel_graph); + void RunOpHardwareOptimize(const std::shared_ptr &kernel_graph); void GraphKernelOptimize(const std::shared_ptr &kernel_graph);