forked from mindspore-Ecosystem/mindspore
!17932 fix bcewithlogitsloss op error in pynative
From: @chujinjin Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
1475a3e121
|
@ -189,6 +189,16 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
|
|||
kernel_graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void GPUSession::RunOpOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void GPUSession::RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
|
@ -558,6 +568,7 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
|
|||
// Prepare the graph
|
||||
auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
RunOpOptimize(kernel_graph);
|
||||
SelectKernel(kernel_graph);
|
||||
RunOpHardwareOptimize(kernel_graph);
|
||||
StartKernelRT();
|
||||
|
|
|
@ -66,6 +66,8 @@ class GPUSession : public SessionBasic {
|
|||
|
||||
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
|
||||
void RunOpOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
|
||||
void RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
|
||||
void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
|
||||
|
|
Loading…
Reference in New Issue