!44544 BCEWithLogitsLossFusion

Merge pull request !44544 from nomindcarry/master
This commit is contained in:
i-robot 2022-11-29 11:42:18 +00:00 committed by Gitee
commit b4b932bf9e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 2 additions and 1 deletions

View File

@ -431,6 +431,7 @@ void RunOpOptimize(const KernelGraphPtr &kernel_graph) {
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
pm->AddPass(std::make_shared<opt::InsertCastGPU>("insert_cast_gpu"));
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);

View File

@ -57,7 +57,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
MS_EXCEPTION_IF_NULL(axis_node);
axis_node->set_abstract(axis_tensor->ToAbstract());
axis_node = kernel_graph->NewValueNode(axis_node);
kernel_graph->AddValueNode(axis_node);
kernel_graph->AddValueNodeToGraph(axis_node);
if (reduction == "sum") {
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode, axis_node};
} else if (reduction == "mean") {