fix bcewithlogitsloss op error in pynative

This commit is contained in:
chujinjin 2021-06-07 19:38:36 +08:00
parent 3ee381fb67
commit 90feb6a6d2
2 changed files with 13 additions and 0 deletions

View File

@ -189,6 +189,16 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
kernel_graph->SetExecOrderByDefault();
}
void GPUSession::RunOpOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}
void GPUSession::RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
@ -558,6 +568,7 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
// Prepare the graph
auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(kernel_graph);
RunOpOptimize(kernel_graph);
SelectKernel(kernel_graph);
RunOpHardwareOptimize(kernel_graph);
StartKernelRT();

View File

@ -66,6 +66,8 @@ class GPUSession : public SessionBasic {
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void RunOpOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);