!9832 expose_allgather_fusion_to_users
From: @gong_zi_yan Reviewed-by: Signed-off-by:
commit b67aaf6773
@@ -942,6 +942,29 @@ bool IsCastBeforMirror(const CNodePtr &node, size_t index) {
   return (type_id != kNumberTypeFloat32);
 }

+static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
+  MS_EXCEPTION_IF_NULL(comm_node);
+  MS_EXCEPTION_IF_NULL(param_node);
+  if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
+    MS_LOG(WARNING) << "The mirror of Receive does not support fusion type now.";
+    return;
+  }
+  auto param = param_node->cast<ParameterPtr>();
+  MS_EXCEPTION_IF_NULL(param);
+  auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
+  MS_EXCEPTION_IF_NULL(prim);
+  auto attrs = prim->attrs();
+  auto param_info = param->param_info();
+  if (!param_info) {
+    MS_LOG(WARNING) << param->ToString() << "does not have parameter info.";
+    return;
+  }
+  int32_t fusion_type = param_info->comm_fusion();
+  attrs[FUSION] = MakeValue<int64_t>(fusion_type);
+  prim->SetAttrs(attrs);
+  MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
+}
+
 void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   size_t node_size = node->inputs().size();
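For orientation: communication operators that carry the same fusion attribute can be merged by the backend into one larger collective, which is what makes the per-parameter fusion index above worth exposing. The following is a self-contained Python sketch of that bucketing idea only; it is illustrative, not MindSpore code, and the tensor names, sizes, and indices are made up.

# Illustrative only: group "gradients" by a user-assigned fusion index, mirroring how the
# backend can merge communication ops that share the same fusion attribute.
from collections import defaultdict

def bucket_by_fusion(grads):
    """grads: list of (name, size_in_elements, fusion_index) tuples (hypothetical)."""
    buckets = defaultdict(list)
    for name, size, fusion in grads:
        buckets[fusion].append((name, size))
    return dict(buckets)

if __name__ == "__main__":
    grads = [("embedding.weight", 4_000_000, 2),
             ("dense1.weight", 65_536, 1),
             ("dense2.weight", 65_536, 1)]
    for fusion, members in bucket_by_fusion(grads).items():
        total = sum(size for _, size in members)
        # one fused collective per bucket instead of one collective per parameter
        print(f"fusion={fusion}: {len(members)} tensors, {total} elements in one collective")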
@@ -1006,11 +1029,19 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node)
         MS_EXCEPTION_IF_NULL(cnode);
         AnfNodePtr pre_node = cnode->input(1);
         InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name);
+        auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
+        // add fusion flag
+        // pipeline mirror would not be set, which should be supported later
+        AddCommOpFusionType(comm_op, param_node_pair.first);
       }
     } else {
       for (auto &op : backward_op) {
         AnfNodePtr pre_node = node->input(index);
         InsertNode(op, node, index, pre_node, func_graph, instance_name);
+        auto comm_op = node->input(index)->cast<CNodePtr>();
+        // add fusion flag
+        // pipeline mirror would not be set, which should be supported later
+        AddCommOpFusionType(comm_op, param_node_pair.first);
       }
     }
   }
@@ -1342,7 +1373,8 @@ std::pair<AnfNodePtr, int64_t> FindSubGraph(const FuncGraphPtr &graph, const Anf
   return std::make_pair(nullptr, 0);
 }

-void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res, const AnfNodePtr &parameter) {
+static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res,
+                              const AnfNodePtr &parameter) {
   Operator op = CreateAllGatherOp(group);
   MS_EXCEPTION_IF_NULL(res.first);
   MS_EXCEPTION_IF_NULL(parameter);
@@ -1360,11 +1392,7 @@ void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int
   }
   // add fusion flag
   MS_EXCEPTION_IF_NULL(allgather);
-  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
-  auto attrs = prim->attrs();
-  // enable fusion flag later when it's supported in backend
-  attrs["fusion"] = MakeValue<int64_t>(1);
-  prim->SetAttrs(attrs);
+  AddCommOpFusionType(allgather, parameter);
 }

 static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
@@ -1419,6 +1447,9 @@ std::string SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNod
   if (!ParameterRequireGrad(parameter)) {
     // only trainable parameters need parallel optimizer
     MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not trainable parameter.";
+  } else if (parameter->cast<ParameterPtr>()->param_info() &&
+             !parameter->cast<ParameterPtr>()->param_info()->parallel_optimizer()) {
+    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " does not need weight shard.";
   } else if (tensor_layout.GenerateOptShardSliceShape() == Status::SUCCESS) {
     // get a totally shard tensor slice shape if the weight is repeated on devices
     // and the shape of the first dimension could be divided
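The new branch above lets a single weight opt out of optimizer weight sharding through its param_info. As a rough, self-contained sketch (not the implementation), the decision order reduces to: trainability, then the per-parameter parallel_optimizer flag, then whether the first dimension divides evenly across devices.

# Illustrative only: the shard / no-shard decision for the parallel optimizer, written as
# plain Python in the same branch order as the C++ above. Inputs here are hypothetical.
def should_shard_weight(requires_grad, parallel_optimizer_flag, first_dim, device_num):
    if not requires_grad:
        return False        # only trainable parameters are considered
    if not parallel_optimizer_flag:
        return False        # user opted this weight out via parallel_optimizer=False
    return first_dim % device_num == 0  # slice only if the first dimension divides evenly

print(should_shard_weight(True, True, 64, 8))    # True: sharded across 8 devices
print(should_shard_weight(True, False, 64, 8))   # False: filtered by the new flag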
@@ -29,6 +29,9 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
   .def_property("init_in_server", &ParamInfo::init_in_server, &ParamInfo::set_init_in_server)
   .def_property("layerwise_parallel", &ParamInfo::layerwise_parallel,
                 &ParamInfo::set_layerwise_parallel)
+  .def_property("parallel_optimizer", &ParamInfo::parallel_optimizer,
+                &ParamInfo::set_parallel_optimizer)
+  .def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
   .def(py::pickle(
     [](const ParamInfo &p) {  // __getstate__
       return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());
@@ -75,8 +75,11 @@ class Parameter(MetaTensor_):
         default_input (Union[Tensor, MetaTensor, Number]): Parameter data, to be set initialized.
         name (str): Name of the child parameter. Default: None.
         requires_grad (bool): True if the parameter requires gradient. Default: True.
-        layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
+        layerwise_parallel (bool): When layerwise_parallel is true in data parallel mode,
             broadcast and gradients communication would not be applied to parameters. Default: False.
+        parallel_optimizer (bool): It is used to filter the weight shard operation in semi auto or auto parallel
+            mode. It works only when enable parallel optimizer in `mindspore.context.set_auto_parallel_context()`.
+            Default: True.

     Example:
         >>> from mindspore import Parameter, Tensor
@@ -132,19 +135,21 @@ class Parameter(MetaTensor_):
         return (
             Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))

-    def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False):
+    def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True):
         self._param_info = ParamInfo()
         self.init_in_server = False
         self.cache_enable = False
         self.name = name
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel
+        self.parallel_optimizer = parallel_optimizer
         # this flag for tensor copy data.
         self.init_flag = False
         # this flag is for ge variable copy data.
         self._is_init = False
         self._inited_param = None
         self._sliced = False
+        self.comm_fusion = 1
         self.is_param_ps = False
         self._cast_type = None
         self._unique = False
@@ -210,7 +215,6 @@ class Parameter(MetaTensor_):
             raise RuntimeError("Must complete following two steps before calling set_param_ps: \
                                1. set_ps_context(enable_ps=True) \
                                2. export MS_ROLE environment variable.")

         if init_in_server and (not self.name.endswith("embedding_table")):
             raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "
                                "sparse operator support initialization in server.".format(self.name))
@@ -218,7 +222,6 @@ class Parameter(MetaTensor_):
         self.init_in_server = init_in_server
         self._param_info.init_in_server = init_in_server


     @property
     def inited_param(self):
         """
@@ -273,6 +276,16 @@ class Parameter(MetaTensor_):
     def sliced(self, sliced_):
         self._sliced = sliced_

+    @property
+    def comm_fusion(self):
+        """Get the fusion type for communication operators corresponding to this parameter."""
+        return self._param_info.comm_fusion
+
+    @comm_fusion.setter
+    def comm_fusion(self, comm_fusion_):
+        """Set the fusion type for communication operators corresponding to this parameter."""
+        self._param_info.comm_fusion = comm_fusion_
+
     @property
     def unique(self):
         """whether the parameter is already unique or not."""
@@ -338,6 +351,17 @@ class Parameter(MetaTensor_):
             raise TypeError("`layerwise_parallel` parameter must be bool type")
         self._param_info.layerwise_parallel = value

+    @property
+    def parallel_optimizer(self):
+        """Return whether the parameter requires weight shard for parallel optimizer."""
+        return self._param_info.parallel_optimizer
+
+    @parallel_optimizer.setter
+    def parallel_optimizer(self, value=True):
+        if not isinstance(value, bool):
+            raise TypeError("`parallel_optimizer` parameter must be bool type")
+        self._param_info.parallel_optimizer = value
+
     @property
     def requires_grad(self):
         """Return whether the parameter requires gradient."""
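Taken together, the two properties above give users per-parameter control. A minimal usage sketch follows; the shape, name, and fusion index are illustrative, while comm_fusion, parallel_optimizer, and enable_parallel_optimizer are the interfaces touched by this commit.

# Minimal usage sketch of the Parameter-level knobs added above (values are illustrative).
import numpy as np
from mindspore import context, Parameter, Tensor

# weight shard only takes effect in semi auto / auto parallel mode with the optimizer enabled
context.set_auto_parallel_context(enable_parallel_optimizer=True)

# exclude this weight from optimizer weight shard and move its gradient
# communication from the default fusion bucket 1 to bucket 2
weight = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)),
                   name="weight", parallel_optimizer=False)
weight.comm_fusion = 2
print(weight.comm_fusion, weight.parallel_optimizer)  # 2 False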
@@ -75,6 +75,12 @@ class ParamInfo {
     return clone;
   }

+  int32_t comm_fusion() const { return fusion_type_; }
+  void set_comm_fusion(int32_t fusion_type) { fusion_type_ = fusion_type; }
+
+  bool parallel_optimizer() const { return parallel_optimizer_; }
+  void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; }
+
  private:
   std::string name_{"Parameter"};
   bool requires_grad_{true};
@@ -84,6 +90,8 @@ class ParamInfo {
   bool cloned_{false};
   std::vector<int32_t> be_cloned_index_;
   int32_t cloned_index_{0};
+  int32_t fusion_type_{1};
+  bool parallel_optimizer_{true};
 };
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_IR_PARAM_INFO_H_
@@ -1087,6 +1087,12 @@ class Cell(Cell_):
         for param in params:
             param.set_param_ps(init_in_server)

+    def set_comm_fusion(self, fusion_type, recurse=True):
+        Validator.check_is_int(fusion_type)
+        for param in self.trainable_params(recurse):
+            param.comm_fusion = fusion_type
+        return self
+

 class GraphKernel(Cell):
     """
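A short usage sketch for the set_comm_fusion method added above: assigning different fusion buckets to different sub-cells. The two-layer network is illustrative only.

# Illustrative sketch: put each sub-cell's trainable parameters into its own fusion bucket.
import mindspore.nn as nn

class TinyNet(nn.Cell):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc1 = nn.Dense(16, 16)
        self.fc2 = nn.Dense(16, 4)

    def construct(self, x):
        return self.fc2(self.fc1(x))

net = TinyNet()
# recurse=True (the default) walks trainable_params of all nested sub-cells
net.fc1.set_comm_fusion(2)
net.fc2.set_comm_fusion(3)
print([(p.name, p.comm_fusion) for p in net.trainable_params()])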
@@ -127,7 +127,7 @@ def get_bprop_all_gather(self):
             instance_name = "grad_" + self.instance_name
             reduce_scatter.set_prim_instance_name(instance_name)
     else:
-        all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", 1)
+        all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
         if self.instance_name:
             instance_name = "grad_" + self.instance_name
             all_reduce.set_prim_instance_name(instance_name)
@@ -242,9 +242,7 @@ def get_bprop_mirror_operator(self):
     mul = P.Mul()
     cast = P.Cast()

-    fusion = 1
-    if hasattr(self, 'fusion'):
-        fusion = self.fusion
+    fusion = self.get_attr_dict()["fusion"]
     all_reduce.add_prim_attr("fusion", fusion)
     if hasattr(self, 'parameter'):
         parameter = self.parameter
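The mirror gradient now reads the fusion index straight from the primitive's attribute dictionary rather than from a hard-coded default. Below is a minimal sketch of that attribute round-trip; P.Mul stands in for a communication primitive purely for illustration.

# Illustration of the attribute round-trip used above: attach a fusion index with
# add_prim_attr and read it back through get_attr_dict. In the commit the attribute
# lives on communication primitives such as _MirrorOperator, not on Mul.
from mindspore.ops import operations as P

op = P.Mul().add_prim_attr("fusion", 2)
print(op.get_attr_dict()["fusion"])  # -> 2, the value the backward AllReduce reuses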
@@ -555,6 +555,7 @@ class _MirrorOperator(PrimitiveWithInfer):
         self.group = group
         self.dev_num = dev_num
         self.mean_flag = mean_flag
+        self.add_prim_attr("fusion", 1)

     def infer_shape(self, x_shape):
         return x_shape
@@ -25,6 +25,7 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.train import Model
 from mindspore.context import ParallelMode
 from tests.dataset_mock import MindData
+import pytest


 class Dataset(MindData):
@@ -125,6 +126,7 @@ def train_common(net):
     return allreduce_fusion_dict


+@pytest.mark.skip(reason="depreciated feature")
 def test_allreduce_fusion_parameters():
     cost_model_context.reset_cost_model_context()
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
@@ -181,6 +183,7 @@ def test_allreduce_fusion_parameters():
     assert computation_time_parameter == 0.1


+@pytest.mark.skip(reason="depreciated feature")
 def test_allreduce_fusion1():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
@@ -205,6 +208,7 @@ def test_allreduce_fusion1():
     cost_model_context.reset_cost_model_context()


+@pytest.mark.skip(reason="depreciated feature")
 # reset_cost_model_context is called, the default value of costmodel_allreduce_fusion_times is 0, step_allreduce_fusion
 # is bypassed.
 def test_allreduce_fusion2():
@@ -220,6 +224,7 @@ def test_allreduce_fusion2():
     cost_model_context.reset_cost_model_context()


+@pytest.mark.skip(reason="depreciated feature")
 def test_allreduce_fusion3():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=3)
@@ -248,6 +253,7 @@ def test_allreduce_fusion3():
     cost_model_context.reset_cost_model_context()


+@pytest.mark.skip(reason="depreciated feature")
 def test_allreduce_fusion4():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
@@ -277,6 +283,7 @@ def test_allreduce_fusion4():
     cost_model_context.reset_cost_model_context()


+@pytest.mark.skip(reason="depreciated feature")
 def test_allreduce_fusion5():
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
@@ -66,15 +66,30 @@ class Net2(nn.Cell):
         return x - y


-def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):
+class Net3(nn.Cell):
+    """Net definition"""
+    def __init__(self, strategy1, strategy2):
+        super(Net3, self).__init__()
+        self.fc1 = P.MatMul().shard(strategy=strategy1)
+        self.fc2 = P.MatMul().shard(strategy=strategy2)
+        self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
+        self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False)
+
+    def construct(self, x, y):
+        x = self.fc1(x, self.p1)
+        x = self.fc2(x, self.p2)
+        return x - y
+
+
+def auto_parallel_compile_net(mode, dev_num, net, strategy1=None, strategy2=None):
     context.set_context(mode=context.GRAPH_MODE)
     context.set_auto_parallel_context(parallel_mode=mode, device_num=dev_num, enable_parallel_optimizer=True)
     inputs = Tensor(np.ones([32, 48]).astype(np.float32))
     label = Tensor(np.zeros([32, 16]).astype(np.float32))
-    net = Net2(strategy1, strategy2)
+    net = net(strategy1, strategy2)
     net = _VirtualDatasetCell(net)
     optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
-    train_network = TrainOneStepCell(net, optimizer)
+    train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
     train_network.set_auto_parallel()
     train_network.set_train()
     _executor.compile(train_network, inputs, label, phase="train", auto_parallel_mode=True)
@@ -83,18 +98,18 @@ def auto_parallel_compile_net(mode, dev_num, strategy1=None, strategy2=None):


 def test_auto_parallel_momentum_1():
-    auto_parallel_compile_net("auto_parallel", 8)
+    auto_parallel_compile_net("auto_parallel", 8, Net2)


 def test_auto_parallel_momentum_2():
     # data parallel case
-    auto_parallel_compile_net("auto_parallel", 8, ((8, 1), (1, 1)), ((8, 1), (1, 1)))
+    auto_parallel_compile_net("auto_parallel", 8, Net2, ((8, 1), (1, 1)), ((8, 1), (1, 1)))


 def test_auto_parallel_momentum_3():
     # hybrid parallel case
     # weight1 could not be shard and weight2 is repeated
-    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
     param_dict = train_network.parameter_layout_dict
     # validate opt_shard_group
     assert not param_dict["weight1"][5]
@@ -104,7 +119,16 @@ def test_auto_parallel_momentum_3():
 def test_auto_parallel_momentum_4():
     # hybrid parallel cases
     # devices are repeatedly used
-    auto_parallel_compile_net("semi_auto_parallel", 32, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
+    auto_parallel_compile_net("semi_auto_parallel", 32, Net2, ((4, 4), (4, 1)), ((4, 4), (4, 2)))
+
+
+def test_auto_parallel_momentum_5():
+    # test parallel optimizer filter
+    train_network = auto_parallel_compile_net("semi_auto_parallel", 32, Net3, ((4, 8), (8, 1)), ((4, 4), (4, 2)))
+    param_dict = train_network.parameter_layout_dict
+    # validate opt_shard_group
+    assert not param_dict["weight1"][5]
+    assert not param_dict["weight2"][5]


 def test_AdamWeightDecay():