!12055 compile graph for merging parameter slice only once

From: @yangzhenzhang
Reviewed-by: @stsuteng
Signed-off-by: @stsuteng
Committed by mindspore-ci-bot on 2021-02-04 14:30:47 +08:00 (via Gitee)
commit 8fc5962418
3 changed files with 184 additions and 39 deletions
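At its core the change is simple memoization: the first checkpoint save builds and compiles one AllGather merge network per sharded parameter and caches it on the network object, so every later save reuses the compiled net instead of rebuilding it. A minimal sketch of the pattern, independent of MindSpore internals (merge_cache and build_merge_net are illustrative names, not MindSpore API):

merge_cache = {}

def get_merge_net(param_name, build_merge_net):
    """Build (and thereby compile) the merge net only once per parameter."""
    if param_name not in merge_cache:
        merge_cache[param_name] = build_merge_net()  # pay the compile cost once
    return merge_cache[param_name]                   # every later save: cache hit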


@@ -90,6 +90,7 @@ class Cell(Cell_):
         self._phase = 'train'
         self._parameter_layout_dict = {}
         self._parallel_parameter_name_list = ()
+        self._parallel_parameter_merge_net_dict = {}
         self._create_time = int(time.time() * 1e9)
         self.phase_prefix = ""
         self.parameter_broadcast_done = False
@@ -224,6 +225,16 @@ class Cell(Cell_):
             raise TypeError("'parallel_parameter_name_list' must be list type.")
         self._parallel_parameter_name_list = value

+    @property
+    def parallel_parameter_merge_net_dict(self):
+        return self._parallel_parameter_merge_net_dict
+
+    @parallel_parameter_merge_net_dict.setter
+    def parallel_parameter_merge_net_dict(self, value):
+        if not isinstance(value, dict):
+            raise TypeError("'parallel_parameter_merge_net_dict' must be dict type.")
+        self._parallel_parameter_merge_net_dict = value
+
     def get_func_graph_proto(self):
         """Return graph binary proto."""
         return _executor._get_func_graph_proto(self, self.phase + "." + str(self.create_time), "anf_ir", True)
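For reference, the new property reads like a plain attribute while the setter enforces the type. A small usage snippet (any bare Cell will do, since the property lives on the base class):

from mindspore.nn import Cell

net = Cell()
net.parallel_parameter_merge_net_dict = {}      # accepted: value is a dict
print(net.parallel_parameter_merge_net_dict)    # -> {}
try:
    net.parallel_parameter_merge_net_dict = []  # any non-dict is rejected
except TypeError as err:
    print(err)  # 'parallel_parameter_merge_net_dict' must be dict type.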
@@ -382,46 +393,63 @@ class Cell(Cell_):
         self._add_attr(key, value)
         self._attr_synced = True

+    def _set_attr_for_parameter(self, name, value):
+        """Set attr for parameter."""
+        cells = self.__dict__.get('_cells')
+        params = self.__dict__.get('_params')
+        if params is None:
+            raise AttributeError("Can not assign params before Cell.__init__() call.")
+        if name in self.__dict__:
+            if self.__dict__[name] is not None:
+                raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.")
+            del self.__dict__[name]
+        if cells and name in cells:
+            raise TypeError("The type of value should be Cell, but got Parameter.")
+        self.insert_param_to_cell(name, value)
+
+    def _set_attr_for_parameter_tuple(self, name, value):
+        """Set attr for parameter tuple."""
+        params = self.__dict__.get('_params')
+        params_list = self.__dict__.get('_params_list')
+        if params is None:
+            raise AttributeError("Can not assign params before Cell.__init__() call.")
+        for item in value:
+            self.insert_param_to_cell(item.name, item, check_name=False)
+        if context.get_context("mode") == context.PYNATIVE_MODE:
+            if name in self.__dict__:
+                del self.__dict__[name]
+            if name in params:
+                del params[name]
+            params_list[name] = value
+        else:
+            object.__setattr__(self, name, value)
+
+    def _set_attr_for_cell(self, name, value):
+        """Set attr for cell."""
+        cells = self.__dict__.get('_cells')
+        params = self.__dict__.get('_params')
+        if cells is None:
+            raise AttributeError("Can not assign cells before Cell.__init__() call.")
+        if name in self.__dict__:
+            del self.__dict__[name]
+        if params and name in params:
+            raise TypeError("The type of value should be Parameter, but got Cell.")
+        if self._auto_prefix:
+            value.update_parameters_name(name + '.')
+        cells[name] = value
+        if hasattr(self, '_cell_init_args'):
+            self.cell_init_args += str({name: value})
+
     def __setattr__(self, name, value):
         cells = self.__dict__.get('_cells')
         params = self.__dict__.get('_params')
         params_list = self.__dict__.get('_params_list')
         tensor_list = self.__dict__.get('_tensor_list')
         if isinstance(value, Parameter):
-            if params is None:
-                raise AttributeError("Can not assign params before Cell.__init__() call.")
-            if name in self.__dict__:
-                if self.__dict__[name] is not None:
-                    raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.")
-                del self.__dict__[name]
-            if cells and name in cells:
-                raise TypeError("The type of value should be Cell, but got Parameter.")
-            self.insert_param_to_cell(name, value)
+            self._set_attr_for_parameter(name, value)
         elif isinstance(value, ParameterTuple):
-            if params is None:
-                raise AttributeError("Can not assign params before Cell.__init__() call.")
-            for item in value:
-                self.insert_param_to_cell(item.name, item, check_name=False)
-            if context.get_context("mode") == context.PYNATIVE_MODE:
-                if name in self.__dict__:
-                    del self.__dict__[name]
-                if name in params:
-                    del params[name]
-                params_list[name] = value
-            else:
-                object.__setattr__(self, name, value)
+            self._set_attr_for_parameter_tuple(name, value)
         elif isinstance(value, Cell):
-            if cells is None:
-                raise AttributeError("Can not assign cells before Cell.__init__() call.")
-            if name in self.__dict__:
-                del self.__dict__[name]
-            if params and name in params:
-                raise TypeError("The type of value should be Parameter, but got Cell.")
-            if self._auto_prefix:
-                value.update_parameters_name(name + '.')
-            cells[name] = value
-            if hasattr(self, '_cell_init_args'):
-                self.cell_init_args += str({name: value})
+            self._set_attr_for_cell(name, value)
         elif params and name in params:
             if isinstance(value, Tensor) and self._params[name] is not None:
                 self._params[name].set_data(value)


@@ -455,6 +455,12 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
     uniform_split = layout[4]
     opt_shard_group = layout[5]

+    allgather_net = None
+    if param_name in net.parallel_parameter_merge_net_dict:
+        allgather_net = net.parallel_parameter_merge_net_dict[param_name]
+    else:
+        logger.info("need to create allgather net for %s", param_name)
+
     if integrated_save:
         if uniform_split == 0:
             raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.")
@@ -462,19 +468,25 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
         # pipeline parallel need to be supported here later
         for dim in tensor_map:
             if dim != -1:
-                if opt_shard_group:
-                    allgather_net = get_allgather_cell(opt_shard_group, True)
-                else:
-                    allgather_net = get_allgather_cell(opt_shard_group, False)
+                if allgather_net is None:
+                    if opt_shard_group:
+                        allgather_net = get_allgather_cell(opt_shard_group, True)
+                    else:
+                        allgather_net = get_allgather_cell(opt_shard_group, False)
+                    net.parallel_parameter_merge_net_dict[param_name] = allgather_net
                 param_data = allgather_net(param_data)
                 if field_size:
                     return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
                 return _reshape_param_data(param_data, dev_mat, tensor_map)

         if opt_shard_group:
-            allgather_net = get_allgather_cell(opt_shard_group, False)
+            if allgather_net is None:
+                allgather_net = get_allgather_cell(opt_shard_group, False)
+                net.parallel_parameter_merge_net_dict[param_name] = allgather_net
             param_data = allgather_net(param_data)
     elif opt_shard_group:
-        allgather_net = get_allgather_cell(opt_shard_group, False)
+        if allgather_net is None:
+            allgather_net = get_allgather_cell(opt_shard_group, False)
+            net.parallel_parameter_merge_net_dict[param_name] = allgather_net
         param_data = allgather_net(param_data)
     return param_data
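The effect shows up across repeated saves of the same network. A hedged sketch (net is assumed to be a Cell compiled under a parallel context so its parameters carry layouts; save_checkpoint is the existing API in mindspore.train.serialization):

from mindspore.train.serialization import save_checkpoint

# First save: _get_merged_param_data builds one AllGather cell per sliced
# parameter, compiles its graph, and caches it in
# net.parallel_parameter_merge_net_dict.
save_checkpoint(net, "epoch_1.ckpt", integrated_save=True)

# Second save: the cached cells are found by parameter name, so the merge
# graph is not compiled again.
save_checkpoint(net, "epoch_2.ckpt", integrated_save=True)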


@@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os

import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from tests.dataset_mock import MindData


class Dataset(MindData):
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class Net(Cell):
    def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, mask=0):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.strided_slice = P.StridedSlice(begin_mask=mask).shard(strategy2)
        self.weight = Parameter(weight, "w1")
        self.mul2 = P.Mul()
        self.weight2 = Parameter(w2, "w2")
        self.begin = begin
        self.end = end
        self.strides = strides

    def construct(self, x, b):
        out = self.strided_slice(
            self.weight, self.begin, self.end, self.strides)
        out = self.mul(x, out)
        out = self.mul2(out, self.weight2)
        return out


_x = Tensor(np.ones([16, 64, 1]), dtype=ms.float32)
_b = Tensor(np.ones([16, 64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)


def clean_all_ckpt_files(folder_path):
    if os.path.exists(folder_path):
        for file_name in os.listdir(folder_path):
            if file_name.endswith('.ckpt') or file_name.endswith('.meta'):
                os.remove(os.path.join(folder_path, file_name))


def compile_net(net):
    context.set_context(save_graphs=True)
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    dataset = Dataset(_x, _b)
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, optimizer=opt)
    ckpt_config = CheckpointConfig(keep_checkpoint_max=1)
    ckpt_path = "./parallel_ckpt"
    ckpt_cb = ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
    model.train(epoch_size, dataset, dataset_sink_mode=False, callbacks=[ckpt_cb])
    # One merge (AllGather) net should have been compiled and cached per sliced parameter.
    assert len(model._train_network.parallel_parameter_merge_net_dict) == 4
    clean_all_ckpt_files(ckpt_path)
    context.reset_auto_parallel_context()


def test_stridedslice_parameter():
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 1), (1, 4, 2))
    strategy2 = ((1, 4, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
              strategy1, strategy2)
    compile_net(net)