forked from mindspore-Ecosystem/mindspore
!44544 BCEWithLogitsLossFusion
Merge pull request !44544 from nomindcarry/master
This commit is contained in:
commit
b4b932bf9e
|
@ -431,6 +431,7 @@ void RunOpOptimize(const KernelGraphPtr &kernel_graph) {
|
|||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
|
||||
pm->AddPass(std::make_shared<opt::InsertCastGPU>("insert_cast_gpu"));
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
|
|
|
@ -57,7 +57,7 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
|
|||
MS_EXCEPTION_IF_NULL(axis_node);
|
||||
axis_node->set_abstract(axis_tensor->ToAbstract());
|
||||
axis_node = kernel_graph->NewValueNode(axis_node);
|
||||
kernel_graph->AddValueNode(axis_node);
|
||||
kernel_graph->AddValueNodeToGraph(axis_node);
|
||||
if (reduction == "sum") {
|
||||
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode, axis_node};
|
||||
} else if (reduction == "mean") {
|
||||
|
|
Loading…
Reference in New Issue