diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index e2b4d55aada..28cddd61119 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -49,6 +49,9 @@ namespace mindspore {
 namespace parallel {
 const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
 const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
+// g_RefMap: when input i of CNode B is a RefKey referring to Parameter C,
+// the map holds one entry with key C and value (B, i).
+static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
 
 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
   if (new_node_input.empty()) {
@@ -1085,11 +1088,19 @@ std::vector<Shapes> ExtractShape(const CNodePtr& node) {
   std::vector<AnfNodePtr> all_inputs = node->inputs();
   std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
 
-  for (auto& input : node_inputs) {
+  size_t inputs_size = all_inputs.size();
+  for (size_t i = 1; i < inputs_size; ++i) {
     Shapes input_shapes;
+    AnfNodePtr input = all_inputs[i];
     if (IsValueNode<RefKey>(input)) {
       auto func_graph = node->func_graph();
       MS_EXCEPTION_IF_NULL(func_graph);
+      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
+      if (parameters.size() != 1) {
+        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
+      }
+      std::pair<AnfNodePtr, int> node_pair = std::make_pair(node, SizeToInt(i));
+      g_RefMap[parameters[0]] = node_pair;
       input_shapes = GetRefKeyNodeShape(input, func_graph);
     } else if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
       input_shapes = GetNodeShape(input);
@@ -1205,14 +1216,20 @@ void CoverSliceShape(const FuncGraphPtr& root) {
   auto parameters = root->parameters();
   for (auto& parameter : parameters) {
     MS_EXCEPTION_IF_NULL(parameter->Shape());
+    auto iter = g_RefMap.find(parameter);
+    if (iter != g_RefMap.end()) {
+      SetParallelShape(parameter, g_RefMap[parameter]);
+      continue;
+    }
     std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
     if (res.first == nullptr) {
       MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape";
     } else {
       SetParallelShape(parameter, res);
-      MS_LOG(DEBUG) << "parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
+      MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
     }
   }
+  g_RefMap.clear();
 }
 
 bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_node) {
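Review note: the C++ change is a two-pass resolution scheme. ExtractShape records, for every parameter reached through a RefKey, which CNode consumes it and at which input index; CoverSliceShape then prefers that recorded entry over the FindSubGraph search and clears the map when done. Below is a minimal, self-contained sketch of that pattern, not the actual MindSpore implementation: Node/NodePtr, RecordRefKeyUse, and ResolveParameters are hypothetical stand-ins (the real code works on AnfNodePtr and goes through SetParallelShape).

// Sketch only: a cache built in pass 1 and consumed in pass 2,
// mirroring g_RefMap in the patch. All types here are simplified stand-ins.
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Node { std::string name; };
using NodePtr = std::shared_ptr<Node>;

// Pass 1 (ExtractShape in the patch): remember that `parameter` is
// consumed by `user` at input position `index`.
static std::map<NodePtr, std::pair<NodePtr, int>> ref_map;

void RecordRefKeyUse(const NodePtr& parameter, const NodePtr& user, int index) {
  ref_map[parameter] = std::make_pair(user, index);
}

// Pass 2 (CoverSliceShape in the patch): prefer the recorded consumer;
// only fall back to a graph search when no entry exists.
void ResolveParameters(const std::vector<NodePtr>& parameters) {
  for (const auto& parameter : parameters) {
    auto iter = ref_map.find(parameter);
    if (iter != ref_map.end()) {
      std::cout << parameter->name << ": resolved via ref map, user "
                << iter->second.first->name << ", input "
                << iter->second.second << "\n";
      continue;
    }
    std::cout << parameter->name << ": falls back to graph search\n";
  }
  ref_map.clear();  // the patch likewise clears g_RefMap after use
}

int main() {
  auto assignsub_weight = std::make_shared<Node>(Node{"assignsub_weight"});
  auto mul_weight = std::make_shared<Node>(Node{"mul_weight"});
  auto assign_sub = std::make_shared<Node>(Node{"AssignSub"});
  RecordRefKeyUse(assignsub_weight, assign_sub, 1);  // as ExtractShape would
  ResolveParameters({assignsub_weight, mul_weight});
  return 0;
}

Keeping the map file-static and clearing it at the end of CoverSliceShape decouples the two passes without threading an extra argument through the call chain, at the cost of global mutable state.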
diff --git a/tests/ut/python/parallel/test_arithmetic.py b/tests/ut/python/parallel/test_arithmetic.py
index 2c7eabc8f2f..4c34c0371ef 100644
--- a/tests/ut/python/parallel/test_arithmetic.py
+++ b/tests/ut/python/parallel/test_arithmetic.py
@@ -13,14 +13,13 @@
 # limitations under the License.
 
 import numpy as np
-from mindspore import context
+import mindspore as ms
+from mindspore import Parameter, Tensor, context
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore import Tensor
-from tests.ut.python.ops.test_math_ops import VirtualLoss
-import mindspore as ms
-from mindspore.common.api import _executor
 from mindspore.ops import composite as C
+from mindspore.common.api import _executor
+from tests.ut.python.ops.test_math_ops import VirtualLoss
 
 
 class NetWithLoss(nn.Cell):
@@ -470,3 +469,30 @@ def test_matmul_floordiv_broadcast2():
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
     _executor.compile(net, x, y, b)
+
+
+def test_assign_sub():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.assign_sub = P.AssignSub()
+            self.mul = P.Mul()
+            self.mul_weight = Parameter(Tensor(np.full([128, 32],
+                                                       0.5, dtype=np.float32)),
+                                        name="mul_weight")
+            self.assignsub_weight = Parameter(Tensor(np.full([128, 32],
+                                                             1.1, dtype=np.float32)),
+                                              name="assignsub_weight")
+
+        def construct(self, x, y, z):
+            out = self.mul(x, self.mul_weight)
+            out = self.assign_sub(self.assignsub_weight, out)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=15)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net = GradWrap(NetWithLoss(Net()))
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    z = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    _executor.compile(net, x, y, z)
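Review note on the test: `self.assign_sub(self.assignsub_weight, out)` passes a Parameter directly as an operator input, which reaches ExtractShape as a RefKey. The patched code records (AssignSub node, input 1) in g_RefMap, and CoverSliceShape then sets the parallel slice shape for assignsub_weight from that entry; before the change such a parameter was only discoverable through FindSubGraph and would presumably land in the "don't need to set parallel shape" branch, leaving its shape unsliced. The unused y/z arguments of construct match the three-input signature that the shared NetWithLoss/GradWrap wrappers in this file appear to expect.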