forked from mindspore-Ecosystem/mindspore
Add defer_inline flag for on-grading primal graph.
This commit is contained in:
parent
6e390ff119
commit
8d36b00426
|
@ -644,6 +644,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|||
|
||||
FuncGraphPtr forward_graph = real_fn->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(forward_graph);
|
||||
forward_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
|
||||
FuncGraphPtr grad_fg = nullptr;
|
||||
{
|
||||
TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
|
||||
|
|
|
@ -338,7 +338,6 @@ class TrainOneStepCell(Cell):
|
|||
super(TrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
|
|
Loading…
Reference in New Issue