forked from mindspore-Ecosystem/mindspore
add_no_elimilate_for_comm_op
This commit is contained in:
parent
0ea683a90e
commit
56b31fce42
|
@ -155,6 +155,7 @@ class AllReduce(PrimitiveWithInfer):
|
|||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('fusion', 0)
|
||||
self.add_prim_attr('index', 0)
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
@ -228,6 +229,7 @@ class AllGather(PrimitiveWithInfer):
|
|||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('fusion', 0)
|
||||
self.add_prim_attr('mean_flag', False)
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
||||
|
@ -345,6 +347,7 @@ class _HostAllGather(PrimitiveWithInfer):
|
|||
validator.check_value_type("rank_id", r, (int,), self.name)
|
||||
self.group_size = len(group)
|
||||
self.add_prim_attr('group', group)
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
||||
|
@ -419,6 +422,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
|||
self.add_prim_attr('rank_size', self.rank_size)
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('fusion', 0)
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
if self.rank_size == 0:
|
||||
|
@ -472,6 +476,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
|
|||
self.op = op
|
||||
self.group_size = len(group)
|
||||
self.add_prim_attr('group', group)
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
if x_shape[0] % self.group_size != 0:
|
||||
|
@ -548,6 +553,7 @@ class Broadcast(PrimitiveWithInfer):
|
|||
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
||||
check_hcom_group_valid(group)
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return x_shape
|
||||
|
@ -594,6 +600,7 @@ class AllSwap(PrimitiveWithCheck):
|
|||
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
||||
self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
|
||||
def __check__(self, tensor_in, send_size, recv_size):
|
||||
validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
|
||||
|
@ -638,6 +645,7 @@ class NeighborExchange(Primitive):
|
|||
self.recv_shapes = recv_shapes
|
||||
self.send_shapes = send_shapes
|
||||
self.recv_type = recv_type
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
|
||||
def __call__(self, tensor):
|
||||
raise NotImplementedError
|
||||
|
@ -677,6 +685,7 @@ class AlltoAll(PrimitiveWithInfer):
|
|||
self.split_dim = split_dim
|
||||
self.concat_dim = concat_dim
|
||||
self.add_prim_attr('group', _get_group(group))
|
||||
self.add_prim_attr('no_elimilate', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
rank_size = get_group_size(_get_group(self.group))
|
||||
|
|
|
@ -107,6 +107,9 @@ class _AutoParallelContext:
|
|||
raise TypeError("The type of pipeline_stage_num must be int.")
|
||||
if stages < 1:
|
||||
raise ValueError("pipeline_stage_num can't be less than 1.")
|
||||
backend = context.get_context("device_target")
|
||||
if backend == "GPU" and stages > 1:
|
||||
raise RuntimeError("Now GPU don't support pipeline parallel.")
|
||||
self.check_context_handle()
|
||||
self._context_handle.set_pipeline_stage_split_num(stages)
|
||||
|
||||
|
|
|
@ -320,6 +320,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
|
|||
} else if (name == "index") {
|
||||
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
|
||||
ASSERT_EQ(converted_ret->ToString(), "0");
|
||||
} else if (name == "no_elimilate") {
|
||||
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
|
||||
ASSERT_EQ(converted_ret->ToString(), "true");
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Test failed";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue