diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index a344607362f..71c8bb2614d 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -618,19 +618,11 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
   for (auto &replace_input : replace_graph->first) {
     auto pre_node = node->input(IntToSize(replace_input.second));
     manager->SetEdge(replace_input.first, 1, pre_node);
-    auto replace_input_cnode = replace_input.first->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(replace_input_cnode);
-    (void)replace_input_cnode->set_operator_info(node->operator_info());
-    replace_input_cnode->set_in_forward_flag(true);  // mark this new cnode is forward node
   }
   // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called
   auto replace_output = replace_graph->second;
   MS_EXCEPTION_IF_NULL(replace_output);
   (void)manager->Replace(node, replace_output);
-  CNodePtr replace_output_cnode = replace_graph->second->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(replace_output_cnode);
-  (void)replace_output_cnode->set_operator_info(node->operator_info());
-  replace_output_cnode->set_in_forward_flag(true);  // mark this new cnode is forward node
 }
 
 int32_t GetTupleGetItemIndex(const CNodePtr &cnode) {
@@ -1994,14 +1986,27 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
     } else if (IsValueNode<Tensor>(node)) {
       StepSplitTensor(node, manager);
     }
   }
+
+  for (auto &node : all_nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (node->isa<CNode>()) {
+      auto cnode = node->cast<CNodePtr>();
+      if (!IsValueNode<Primitive>(cnode->input(0))) {
+        continue;
+      }
+      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
+      if (distribute_operator == nullptr) {
+        continue;
+      }
+      // StepReplace
+      StepReplace(distribute_operator, cnode);
+    }
+  }
 }
 
 namespace {
diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py
index 6252116efee..0a327fd0e8f 100644
--- a/mindspore/train/dataset_helper.py
+++ b/mindspore/train/dataset_helper.py
@@ -83,12 +83,6 @@ class _DatasetIter:
         self.dataset = dataset
         dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
         self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
-        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
-        # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
-        # times the batch dimension of tensors for run
-        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
-            device_num = _get_device_num()
-            self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
 
     def __iter__(self):
         self.ind = 0
@@ -119,6 +113,12 @@ class _DatasetIterMSLoopSink(_DatasetIter):
     def __init__(self, dataset):
         super(_DatasetIterMSLoopSink, self).__init__(dataset)
        self.loop_count = self.get_loop_count(dataset)
+        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
+        # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
+        # times the batch dimension of tensors for run. Now only support LoopSink.
+        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+            device_num = _get_device_num()
+            self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
 
         def op():
             return tuple()
diff --git a/tests/ut/python/parallel/test_gather_v2.py b/tests/ut/python/parallel/test_gather_v2.py
index a9e01b3c24f..793f3b91c43 100644
--- a/tests/ut/python/parallel/test_gather_v2.py
+++ b/tests/ut/python/parallel/test_gather_v2.py
@@ -170,4 +170,3 @@ def test_gatherv2_auto1():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y)
-
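
For reference, a minimal sketch of the shape expansion the moved comment describes: under SEMI_AUTO_PARALLEL or AUTO_PARALLEL, compilation sees "full" shapes whose batch dimension is device_num times the per-device batch dimension, while each device runs on its own slice. This is an illustration only; the helper name to_full_shapes_sketch and the assumption that only the leading dimension is scaled are mine, not the actual MindSpore _to_full_shapes implementation.

# Illustrative sketch (assumed behavior, not MindSpore's _to_full_shapes):
# scale the batch (first) dimension of every dataset shape by device_num.
def to_full_shapes_sketch(dataset_shapes, device_num):
    full_shapes = []
    for shape in dataset_shapes:
        new_shape = list(shape)
        new_shape[0] = shape[0] * device_num  # batch dim grows device_num times for compile
        full_shapes.append(tuple(new_shape))
    return full_shapes

# Example: a per-device batch of 32 on 8 devices compiles with batch 256.
assert to_full_shapes_sketch([(32, 3, 224, 224), (32,)], 8) == [(256, 3, 224, 224), (256,)]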