!49776 Add depend node to avoid isolated node loss.

Merge pull request !49776 from Margaret_wangrui/parallel_isolated
This commit is contained in:
i-robot 2023-03-08 03:09:19 +00:00 committed by Gitee
commit f9aead948e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 6 additions and 5 deletions

View File

@@ -207,14 +207,12 @@ def get_bprop_mirror_micro_step_operator(self):
z = F.depend(z, dout)
real_grad = all_reduce(z)
real_grad = F.tensor_mul(real_grad, scale)
assign(z, real_grad)
assign_out = z
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
else:
if issubclass_(F.typeof(dout), mstype.tensor):
z = F.depend(z, dout)
real_grad = all_reduce(z)
assign(z, real_grad)
assign_out = z
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
if opt_shard:
return (real_grad, cast(out_tensor, dtype(z)))
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out)

View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -305,6 +305,7 @@ class _MiniStepAllGather(PrimitiveWithInfer):
self.grad_accumulation_step = grad_accumulation_step
self.mean_flag = mean_flag
self.add_prim_attr('order_enforce_skip', True)
self.add_prim_attr('side_effect_backprop_mem', True)
def infer_shape(self, x_shape, z_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
@@ -1030,6 +1031,7 @@ class _MirrorMiniStepOperator(PrimitiveWithInfer):
self.mean_flag = mean_flag
self.grad_accumulation_step = grad_accumulation_step
self.add_prim_attr('order_enforce_skip', True)
self.add_prim_attr('side_effect_backprop_mem', True)
def infer_shape(self, x_shape, z_shape):
return x_shape
@@ -1186,6 +1188,7 @@ class _MirrorMicroStepOperator(PrimitiveWithInfer):
self.dev_num = dev_num
self.mean_flag = mean_flag
self.add_prim_attr('order_enforce_skip', True)
self.add_prim_attr('side_effect_backprop_mem', True)
def infer_shape(self, x_shape, z_shape):
return x_shape