add_no_elimilate_for_comm_op

This commit is contained in:
lichenever 2021-09-02 17:25:55 +08:00
parent 0ea683a90e
commit 56b31fce42
3 changed files with 15 additions and 0 deletions

View File

@ -155,6 +155,7 @@ class AllReduce(PrimitiveWithInfer):
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('index', 0)
self.add_prim_attr('no_elimilate', True)
def infer_shape(self, x_shape):
return x_shape
@ -228,6 +229,7 @@ class AllGather(PrimitiveWithInfer):
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('mean_flag', False)
self.add_prim_attr('no_elimilate', True)
def infer_shape(self, x_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
@ -345,6 +347,7 @@ class _HostAllGather(PrimitiveWithInfer):
validator.check_value_type("rank_id", r, (int,), self.name)
self.group_size = len(group)
self.add_prim_attr('group', group)
self.add_prim_attr('no_elimilate', True)
def infer_shape(self, x_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
@ -419,6 +422,7 @@ class ReduceScatter(PrimitiveWithInfer):
self.add_prim_attr('rank_size', self.rank_size)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('no_elimilate', True)
def infer_shape(self, x_shape):
if self.rank_size == 0:
@ -472,6 +476,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
self.op = op
self.group_size = len(group)
self.add_prim_attr('group', group)
self.add_prim_attr('no_elimilate', True)
def infer_shape(self, x_shape):
if x_shape[0] % self.group_size != 0:
@ -548,6 +553,7 @@ class Broadcast(PrimitiveWithInfer):
validator.check_value_type('group', _get_group(group), (str,), self.name)
check_hcom_group_valid(group)
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('no_elimilate', True)
def infer_shape(self, x_shape):
return x_shape
@ -594,6 +600,7 @@ class AllSwap(PrimitiveWithCheck):
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('no_elimilate', True)
def __check__(self, tensor_in, send_size, recv_size):
validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
@ -638,6 +645,7 @@ class NeighborExchange(Primitive):
self.recv_shapes = recv_shapes
self.send_shapes = send_shapes
self.recv_type = recv_type
self.add_prim_attr('no_elimilate', True)
def __call__(self, tensor):
raise NotImplementedError
@ -677,6 +685,7 @@ class AlltoAll(PrimitiveWithInfer):
self.split_dim = split_dim
self.concat_dim = concat_dim
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('no_elimilate', True)
def infer_shape(self, x_shape):
rank_size = get_group_size(_get_group(self.group))

View File

@ -107,6 +107,9 @@ class _AutoParallelContext:
raise TypeError("The type of pipeline_stage_num must be int.")
if stages < 1:
raise ValueError("pipeline_stage_num can't be less than 1.")
backend = context.get_context("device_target")
if backend == "GPU" and stages > 1:
raise RuntimeError("Now GPU don't support pipeline parallel.")
self.check_context_handle()
self._context_handle.set_pipeline_stage_split_num(stages)

View File

@ -320,6 +320,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
} else if (name == "index") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "0");
} else if (name == "no_elimilate") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "true");
} else {
MS_LOG(EXCEPTION) << "Test failed";
}