Add defer_inline flag for on-grading primal graph.

This commit is contained in:
Zhang Qinghua 2021-03-16 20:22:57 +08:00
parent 6e390ff119
commit 8d36b00426
2 changed files with 1 additions and 1 deletions

View File

@ -644,6 +644,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
FuncGraphPtr forward_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(forward_graph);
forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
FuncGraphPtr grad_fg = nullptr;
{
TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));

View File

@ -338,7 +338,6 @@ class TrainOneStepCell(Cell):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)