forked from mindspore-Ecosystem/mindspore
add_no_elimilate_for_comm_op
This commit is contained in:
parent
0ea683a90e
commit
56b31fce42
|
@ -155,6 +155,7 @@ class AllReduce(PrimitiveWithInfer):
|
||||||
self.add_prim_attr('group', _get_group(group))
|
self.add_prim_attr('group', _get_group(group))
|
||||||
self.add_prim_attr('fusion', 0)
|
self.add_prim_attr('fusion', 0)
|
||||||
self.add_prim_attr('index', 0)
|
self.add_prim_attr('index', 0)
|
||||||
|
self.add_prim_attr('no_elimilate', True)
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
@ -228,6 +229,7 @@ class AllGather(PrimitiveWithInfer):
|
||||||
self.add_prim_attr('group', _get_group(group))
|
self.add_prim_attr('group', _get_group(group))
|
||||||
self.add_prim_attr('fusion', 0)
|
self.add_prim_attr('fusion', 0)
|
||||||
self.add_prim_attr('mean_flag', False)
|
self.add_prim_attr('mean_flag', False)
|
||||||
|
self.add_prim_attr('no_elimilate', True)
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
||||||
|
@ -345,6 +347,7 @@ class _HostAllGather(PrimitiveWithInfer):
|
||||||
validator.check_value_type("rank_id", r, (int,), self.name)
|
validator.check_value_type("rank_id", r, (int,), self.name)
|
||||||
self.group_size = len(group)
|
self.group_size = len(group)
|
||||||
self.add_prim_attr('group', group)
|
self.add_prim_attr('group', group)
|
||||||
|
self.add_prim_attr('no_elimilate', True)
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
validator.check_positive_int(len(x_shape), "x shape", self.name)
|
||||||
|
@ -419,6 +422,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
||||||
self.add_prim_attr('rank_size', self.rank_size)
|
self.add_prim_attr('rank_size', self.rank_size)
|
||||||
self.add_prim_attr('group', _get_group(group))
|
self.add_prim_attr('group', _get_group(group))
|
||||||
self.add_prim_attr('fusion', 0)
|
self.add_prim_attr('fusion', 0)
|
||||||
|
self.add_prim_attr('no_elimilate', True)
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
if self.rank_size == 0:
|
if self.rank_size == 0:
|
||||||
|
@ -472,6 +476,7 @@ class _HostReduceScatter(PrimitiveWithInfer):
|
||||||
self.op = op
|
self.op = op
|
||||||
self.group_size = len(group)
|
self.group_size = len(group)
|
||||||
self.add_prim_attr('group', group)
|
self.add_prim_attr('group', group)
|
||||||
|
self.add_prim_attr('no_elimilate', True)
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
if x_shape[0] % self.group_size != 0:
|
if x_shape[0] % self.group_size != 0:
|
||||||
|
@ -548,6 +553,7 @@ class Broadcast(PrimitiveWithInfer):
|
||||||
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
||||||
check_hcom_group_valid(group)
|
check_hcom_group_valid(group)
|
||||||
self.add_prim_attr('group', _get_group(group))
|
self.add_prim_attr('group', _get_group(group))
|
||||||
|
self.add_prim_attr('no_elimilate', True)
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
return x_shape
|
return x_shape
|
||||||
|
@ -594,6 +600,7 @@ class AllSwap(PrimitiveWithCheck):
|
||||||
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
validator.check_value_type('group', _get_group(group), (str,), self.name)
|
||||||
self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
|
self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
|
||||||
self.add_prim_attr('group', _get_group(group))
|
self.add_prim_attr('group', _get_group(group))
|
||||||
|
self.add_prim_attr('no_elimilate', True)
|
||||||
|
|
||||||
def __check__(self, tensor_in, send_size, recv_size):
|
def __check__(self, tensor_in, send_size, recv_size):
|
||||||
validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
|
validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
|
||||||
|
@ -638,6 +645,7 @@ class NeighborExchange(Primitive):
|
||||||
self.recv_shapes = recv_shapes
|
self.recv_shapes = recv_shapes
|
||||||
self.send_shapes = send_shapes
|
self.send_shapes = send_shapes
|
||||||
self.recv_type = recv_type
|
self.recv_type = recv_type
|
||||||
|
self.add_prim_attr('no_elimilate', True)
|
||||||
|
|
||||||
def __call__(self, tensor):
|
def __call__(self, tensor):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -677,6 +685,7 @@ class AlltoAll(PrimitiveWithInfer):
|
||||||
self.split_dim = split_dim
|
self.split_dim = split_dim
|
||||||
self.concat_dim = concat_dim
|
self.concat_dim = concat_dim
|
||||||
self.add_prim_attr('group', _get_group(group))
|
self.add_prim_attr('group', _get_group(group))
|
||||||
|
self.add_prim_attr('no_elimilate', True)
|
||||||
|
|
||||||
def infer_shape(self, x_shape):
|
def infer_shape(self, x_shape):
|
||||||
rank_size = get_group_size(_get_group(self.group))
|
rank_size = get_group_size(_get_group(self.group))
|
||||||
|
|
|
@ -107,6 +107,9 @@ class _AutoParallelContext:
|
||||||
raise TypeError("The type of pipeline_stage_num must be int.")
|
raise TypeError("The type of pipeline_stage_num must be int.")
|
||||||
if stages < 1:
|
if stages < 1:
|
||||||
raise ValueError("pipeline_stage_num can't be less than 1.")
|
raise ValueError("pipeline_stage_num can't be less than 1.")
|
||||||
|
backend = context.get_context("device_target")
|
||||||
|
if backend == "GPU" and stages > 1:
|
||||||
|
raise RuntimeError("Now GPU don't support pipeline parallel.")
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
self._context_handle.set_pipeline_stage_split_num(stages)
|
self._context_handle.set_pipeline_stage_split_num(stages)
|
||||||
|
|
||||||
|
|
|
@ -320,6 +320,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
|
||||||
} else if (name == "index") {
|
} else if (name == "index") {
|
||||||
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
|
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
|
||||||
ASSERT_EQ(converted_ret->ToString(), "0");
|
ASSERT_EQ(converted_ret->ToString(), "0");
|
||||||
|
} else if (name == "no_elimilate") {
|
||||||
|
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
|
||||||
|
ASSERT_EQ(converted_ret->ToString(), "true");
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Test failed";
|
MS_LOG(EXCEPTION) << "Test failed";
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue