!49776 Add depend node to avoid isolated node loss.
Merge pull request !49776 from Margaret_wangrui/parallel_isolated
This commit is contained in:
commit
f9aead948e
|
@ -207,14 +207,12 @@ def get_bprop_mirror_micro_step_operator(self):
|
|||
z = F.depend(z, dout)
|
||||
real_grad = all_reduce(z)
|
||||
real_grad = F.tensor_mul(real_grad, scale)
|
||||
assign(z, real_grad)
|
||||
assign_out = z
|
||||
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
|
||||
else:
|
||||
if issubclass_(F.typeof(dout), mstype.tensor):
|
||||
z = F.depend(z, dout)
|
||||
real_grad = all_reduce(z)
|
||||
assign(z, real_grad)
|
||||
assign_out = z
|
||||
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
|
||||
if opt_shard:
|
||||
return (real_grad, cast(out_tensor, dtype(z)))
|
||||
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -305,6 +305,7 @@ class _MiniStepAllGather(PrimitiveWithInfer):
|
|||
self.grad_accumulation_step = grad_accumulation_step
|
||||
self.mean_flag = mean_flag
|
||||
self.add_prim_attr('order_enforce_skip', True)
|
||||
self.add_prim_attr('side_effect_backprop_mem', True)
|
||||
|
||||
def infer_shape(self, x_shape, z_shape):
|
||||
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
||||
|
@ -1030,6 +1031,7 @@ class _MirrorMiniStepOperator(PrimitiveWithInfer):
|
|||
self.mean_flag = mean_flag
|
||||
self.grad_accumulation_step = grad_accumulation_step
|
||||
self.add_prim_attr('order_enforce_skip', True)
|
||||
self.add_prim_attr('side_effect_backprop_mem', True)
|
||||
|
||||
def infer_shape(self, x_shape, z_shape):
|
||||
return x_shape
|
||||
|
@ -1186,6 +1188,7 @@ class _MirrorMicroStepOperator(PrimitiveWithInfer):
|
|||
self.dev_num = dev_num
|
||||
self.mean_flag = mean_flag
|
||||
self.add_prim_attr('order_enforce_skip', True)
|
||||
self.add_prim_attr('side_effect_backprop_mem', True)
|
||||
|
||||
def infer_shape(self, x_shape, z_shape):
|
||||
return x_shape
|
||||
|
|
Loading…
Reference in New Issue