From 05c003ae6b1c6b89338c1b67e4c61ad9396ba6a0 Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Thu, 3 Sep 2020 20:27:10 +0800
Subject: [PATCH] origin/semi_auto_parallel_reshape_parameter_has_another_user

---
 .../ccsrc/frontend/parallel/step_parallel.cc | 28 +++++++++++++++++++
 .../ccsrc/frontend/parallel/step_parallel.h  |  2 ++
 .../parallel/test_auto_parallel_reshape.py   | 22 +++++++++++++++
 3 files changed, 52 insertions(+)

diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index b69b83f2576..fb53fedb01b 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1645,8 +1645,36 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
   return nullptr;
 }
 
+std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) {
+  FuncGraphManagerPtr manager = node->func_graph()->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  AnfNodeIndexSet node_set = manager->node_users()[node];
+  for (auto &node_pair : node_set) {
+    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
+    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(prim_anf_node);
+    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(node_prim);
+    if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == RESHAPE) {
+      continue;
+    }
+    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
+      auto layout = GetInputLayoutFromCNode(node_pair);
+      return std::make_shared<TensorLayout>(layout);
+    }
+  }
+  return nullptr;
+}
+
 std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
   // Create DataParallel tensor layout for parameter(support WideDeep).
+  auto next_layout = FindParameterNextLayout(node);
+  if (next_layout != nullptr) {
+    return next_layout;
+  }
   CheckGlobalDeviceManager();
   int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
   TensorLayout input_tensor_layout;
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
index ca049d1704d..a9a4d941b25 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -156,6 +156,8 @@ using ParameterUsersInfo = std::pair
+std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node);
+
 ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
 
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/tests/ut/python/parallel/test_auto_parallel_reshape.py b/tests/ut/python/parallel/test_auto_parallel_reshape.py
index bde20845621..a54660bf6cb 100644
--- a/tests/ut/python/parallel/test_auto_parallel_reshape.py
+++ b/tests/ut/python/parallel/test_auto_parallel_reshape.py
@@ -292,3 +292,25 @@ def test_reshape_auto_6():
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
     net.set_auto_parallel()
     _executor.compile(net, x, y)
+
+def test_reshape_auto_7():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape = P.Reshape()
+            self.mul = P.Mul().set_strategy(((1, 2, 4), (2, 4)))
+            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            weight = self.reshape(self.mul_weight, (1, 128, 96))
+            out = self.mul(weight, self.mul_weight)
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 28]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x)
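
Note (not part of the patch above): one minimal way to exercise only the new case locally is a small sketch like the following, assuming MindSpore's Python unit tests can be collected by pytest from the repository root; the helper script itself is hypothetical.

    # Hypothetical local check: run only the newly added UT with pytest,
    # assuming pytest is installed and this is executed from the repo root.
    import subprocess

    subprocess.run(
        ["pytest", "-q",
         "tests/ut/python/parallel/test_auto_parallel_reshape.py::test_reshape_auto_7"],
        check=True,
    )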