From 56b31fce428e6a2746fc3a2a37c403d7ca74656a Mon Sep 17 00:00:00 2001 From: lichenever Date: Thu, 2 Sep 2021 17:25:55 +0800 Subject: [PATCH] add_no_elimilate_for_comm_op --- mindspore/ops/operations/comm_ops.py | 9 +++++++++ mindspore/parallel/_auto_parallel_context.py | 3 +++ tests/ut/cpp/parallel/step_parallel_test.cc | 3 +++ 3 files changed, 15 insertions(+) diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 70cbc5ea0a8..59e8b93f421 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -155,6 +155,7 @@ class AllReduce(PrimitiveWithInfer): self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('fusion', 0) self.add_prim_attr('index', 0) + self.add_prim_attr('no_elimilate', True) def infer_shape(self, x_shape): return x_shape @@ -228,6 +229,7 @@ class AllGather(PrimitiveWithInfer): self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('fusion', 0) self.add_prim_attr('mean_flag', False) + self.add_prim_attr('no_elimilate', True) def infer_shape(self, x_shape): validator.check_positive_int(len(x_shape), "x shape", self.name) @@ -345,6 +347,7 @@ class _HostAllGather(PrimitiveWithInfer): validator.check_value_type("rank_id", r, (int,), self.name) self.group_size = len(group) self.add_prim_attr('group', group) + self.add_prim_attr('no_elimilate', True) def infer_shape(self, x_shape): validator.check_positive_int(len(x_shape), "x shape", self.name) @@ -419,6 +422,7 @@ class ReduceScatter(PrimitiveWithInfer): self.add_prim_attr('rank_size', self.rank_size) self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('fusion', 0) + self.add_prim_attr('no_elimilate', True) def infer_shape(self, x_shape): if self.rank_size == 0: @@ -472,6 +476,7 @@ class _HostReduceScatter(PrimitiveWithInfer): self.op = op self.group_size = len(group) self.add_prim_attr('group', group) + self.add_prim_attr('no_elimilate', True) def infer_shape(self, x_shape): if x_shape[0] % 
self.group_size != 0: @@ -548,6 +553,7 @@ class Broadcast(PrimitiveWithInfer): validator.check_value_type('group', _get_group(group), (str,), self.name) check_hcom_group_valid(group) self.add_prim_attr('group', _get_group(group)) + self.add_prim_attr('no_elimilate', True) def infer_shape(self, x_shape): return x_shape @@ -594,6 +600,7 @@ class AllSwap(PrimitiveWithCheck): validator.check_value_type('group', _get_group(group), (str,), self.name) self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out']) self.add_prim_attr('group', _get_group(group)) + self.add_prim_attr('no_elimilate', True) def __check__(self, tensor_in, send_size, recv_size): validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name) @@ -638,6 +645,7 @@ class NeighborExchange(Primitive): self.recv_shapes = recv_shapes self.send_shapes = send_shapes self.recv_type = recv_type + self.add_prim_attr('no_elimilate', True) def __call__(self, tensor): raise NotImplementedError @@ -677,6 +685,7 @@ class AlltoAll(PrimitiveWithInfer): self.split_dim = split_dim self.concat_dim = concat_dim self.add_prim_attr('group', _get_group(group)) + self.add_prim_attr('no_elimilate', True) def infer_shape(self, x_shape): rank_size = get_group_size(_get_group(self.group)) diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 2ead99237b2..d15b39b3018 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -107,6 +107,9 @@ class _AutoParallelContext: raise TypeError("The type of pipeline_stage_num must be int.") if stages < 1: raise ValueError("pipeline_stage_num can't be less than 1.") + backend = context.get_context("device_target") + if backend == "GPU" and stages > 1: + raise RuntimeError("GPU does not support pipeline parallel for now.") self.check_context_handle() self._context_handle.set_pipeline_stage_split_num(stages) diff --git 
a/tests/ut/cpp/parallel/step_parallel_test.cc b/tests/ut/cpp/parallel/step_parallel_test.cc index 4a63d541563..468f1b400ab 100644 --- a/tests/ut/cpp/parallel/step_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_parallel_test.cc @@ -320,6 +320,9 @@ TEST_F(TestStepParallel, CreatOpInstance) { } else if (name == "index") { parse::ConvertData(py::cast(item.second), &converted_ret); ASSERT_EQ(converted_ret->ToString(), "0"); + } else if (name == "no_elimilate") { + parse::ConvertData(py::cast(item.second), &converted_ret); + ASSERT_EQ(converted_ret->ToString(), "true"); } else { MS_LOG(EXCEPTION) << "Test failed"; }