forked from mindspore-Ecosystem/mindspore
!12055 compile graph for merging parameter slice only once
From: @yangzhenzhang
Reviewed-by: @stsuteng
Signed-off-by: @stsuteng
commit 8fc5962418
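In short: before this change, saving a checkpoint under parallel training rebuilt (and recompiled) an AllGather merge net for every sliced parameter on every save. The hunks below cache each merge net in a new `Cell.parallel_parameter_merge_net_dict`, so the merge graph is compiled only once per parameter. A minimal sketch of the caching pattern, distilled from the diff — the standalone function and its `allgather_first_dim` argument are illustrative, not part of the commit, and the import path for `get_allgather_cell` is an assumption:

    # Illustrative sketch only, not code from this commit.
    from mindspore.parallel._cell_wrapper import get_allgather_cell  # assumed location

    def _lookup_or_build_merge_net(net, param_name, opt_shard_group, allgather_first_dim):
        # Reuse the merge net compiled during an earlier checkpoint save.
        if param_name in net.parallel_parameter_merge_net_dict:
            return net.parallel_parameter_merge_net_dict[param_name]
        # First save for this parameter: build the AllGather cell once and cache it.
        allgather_net = get_allgather_cell(opt_shard_group, allgather_first_dim)
        net.parallel_parameter_merge_net_dict[param_name] = allgather_net
        return allgather_net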
@@ -90,6 +90,7 @@ class Cell(Cell_):
         self._phase = 'train'
         self._parameter_layout_dict = {}
         self._parallel_parameter_name_list = ()
+        self._parallel_parameter_merge_net_dict = {}
         self._create_time = int(time.time() * 1e9)
         self.phase_prefix = ""
         self.parameter_broadcast_done = False
@@ -224,6 +225,16 @@ class Cell(Cell_):
             raise TypeError("'parallel_parameter_name_list' must be list type.")
         self._parallel_parameter_name_list = value

+    @property
+    def parallel_parameter_merge_net_dict(self):
+        return self._parallel_parameter_merge_net_dict
+
+    @parallel_parameter_merge_net_dict.setter
+    def parallel_parameter_merge_net_dict(self, value):
+        if not isinstance(value, dict):
+            raise TypeError("'parallel_parameter_merge_net_dict' must be dict type.")
+        self._parallel_parameter_merge_net_dict = value
+
     def get_func_graph_proto(self):
         """Return graph binary proto."""
         return _executor._get_func_graph_proto(self, self.phase + "." + str(self.create_time), "anf_ir", True)
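A quick illustration of the accessor pair added above (hypothetical snippet, assuming this patched build of MindSpore):

    from mindspore.nn import Cell

    net = Cell()
    net.parallel_parameter_merge_net_dict = {}      # accepted: value is a dict
    try:
        net.parallel_parameter_merge_net_dict = []  # rejected by the setter above
    except TypeError as err:
        print(err)  # 'parallel_parameter_merge_net_dict' must be dict type.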
@@ -382,46 +393,63 @@ class Cell(Cell_):
             self._add_attr(key, value)
         self._attr_synced = True

+    def _set_attr_for_parameter(self, name, value):
+        """Set attr for parameter."""
+        cells = self.__dict__.get('_cells')
+        params = self.__dict__.get('_params')
+        if params is None:
+            raise AttributeError("Can not assign params before Cell.__init__() call.")
+        if name in self.__dict__:
+            if self.__dict__[name] is not None:
+                raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.")
+            del self.__dict__[name]
+        if cells and name in cells:
+            raise TypeError("The type of value should be Cell, but got Parameter.")
+        self.insert_param_to_cell(name, value)
+
+    def _set_attr_for_parameter_tuple(self, name, value):
+        """Set attr for parameter tuple."""
+        params = self.__dict__.get('_params')
+        params_list = self.__dict__.get('_params_list')
+        if params is None:
+            raise AttributeError("Can not assign params before Cell.__init__() call.")
+        for item in value:
+            self.insert_param_to_cell(item.name, item, check_name=False)
+        if context.get_context("mode") == context.PYNATIVE_MODE:
+            if name in self.__dict__:
+                del self.__dict__[name]
+            if name in params:
+                del params[name]
+            params_list[name] = value
+        else:
+            object.__setattr__(self, name, value)
+
+    def _set_attr_for_cell(self, name, value):
+        """Set attr for cell."""
+        cells = self.__dict__.get('_cells')
+        params = self.__dict__.get('_params')
+        if cells is None:
+            raise AttributeError("Can not assign cells before Cell.__init__() call.")
+        if name in self.__dict__:
+            del self.__dict__[name]
+        if params and name in params:
+            raise TypeError("The type of value should be Parameter, but got Cell.")
+        if self._auto_prefix:
+            value.update_parameters_name(name + '.')
+        cells[name] = value
+        if hasattr(self, '_cell_init_args'):
+            self.cell_init_args += str({name: value})
+
     def __setattr__(self, name, value):
         cells = self.__dict__.get('_cells')
         params = self.__dict__.get('_params')
         params_list = self.__dict__.get('_params_list')
         tensor_list = self.__dict__.get('_tensor_list')
         if isinstance(value, Parameter):
-            if params is None:
-                raise AttributeError("Can not assign params before Cell.__init__() call.")
-            if name in self.__dict__:
-                if self.__dict__[name] is not None:
-                    raise TypeError("The type of value should not be Parameter or Cell, but got Parameter.")
-                del self.__dict__[name]
-            if cells and name in cells:
-                raise TypeError("The type of value should be Cell, but got Parameter.")
-            self.insert_param_to_cell(name, value)
+            self._set_attr_for_parameter(name, value)
         elif isinstance(value, ParameterTuple):
-            if params is None:
-                raise AttributeError("Can not assign params before Cell.__init__() call.")
-            for item in value:
-                self.insert_param_to_cell(item.name, item, check_name=False)
-            if context.get_context("mode") == context.PYNATIVE_MODE:
-                if name in self.__dict__:
-                    del self.__dict__[name]
-                if name in params:
-                    del params[name]
-                params_list[name] = value
-            else:
-                object.__setattr__(self, name, value)
+            self._set_attr_for_parameter_tuple(name, value)
         elif isinstance(value, Cell):
-            if cells is None:
-                raise AttributeError("Can not assign cells before Cell.__init__() call.")
-            if name in self.__dict__:
-                del self.__dict__[name]
-            if params and name in params:
-                raise TypeError("The type of value should be Parameter, but got Cell.")
-            if self._auto_prefix:
-                value.update_parameters_name(name + '.')
-            cells[name] = value
-            if hasattr(self, '_cell_init_args'):
-                self.cell_init_args += str({name: value})
+            self._set_attr_for_cell(name, value)
         elif params and name in params:
             if isinstance(value, Tensor) and self._params[name] is not None:
                 self._params[name].set_data(value)
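To make the refactor concrete, a hypothetical snippet tracing which helper each assignment now reaches (`Demo` and the shapes are invented for illustration; behavior is unchanged by the refactor):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, Parameter
    from mindspore.nn import Cell, Dense

    class Demo(Cell):
        def __init__(self):
            super().__init__()
            # Parameter value -> __setattr__ dispatches to _set_attr_for_parameter
            self.w = Parameter(Tensor(np.ones([2]), ms.float32), name="w")
            # Cell value -> __setattr__ dispatches to _set_attr_for_cell
            self.fc = Dense(2, 2)

    demo = Demo()
    # Tensor assigned to an existing parameter name -> set_data branch at the end
    demo.w = Tensor(np.zeros([2]), ms.float32)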
@@ -455,6 +455,12 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
     uniform_split = layout[4]
     opt_shard_group = layout[5]

+    allgather_net = None
+    if param_name in net.parallel_parameter_merge_net_dict:
+        allgather_net = net.parallel_parameter_merge_net_dict[param_name]
+    else:
+        logger.info("need to create allgather net for %s", param_name)
+
     if integrated_save:
         if uniform_split == 0:
             raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.")
@@ -462,19 +468,25 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
         # pipeline parallel need to be supported here later
         for dim in tensor_map:
             if dim != -1:
-                if opt_shard_group:
-                    allgather_net = get_allgather_cell(opt_shard_group, True)
-                else:
-                    allgather_net = get_allgather_cell(opt_shard_group, False)
+                if allgather_net is None:
+                    if opt_shard_group:
+                        allgather_net = get_allgather_cell(opt_shard_group, True)
+                    else:
+                        allgather_net = get_allgather_cell(opt_shard_group, False)
+                    net.parallel_parameter_merge_net_dict[param_name] = allgather_net
                 param_data = allgather_net(param_data)
                 if field_size:
                     return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
                 return _reshape_param_data(param_data, dev_mat, tensor_map)
         if opt_shard_group:
-            allgather_net = get_allgather_cell(opt_shard_group, False)
+            if allgather_net is None:
+                allgather_net = get_allgather_cell(opt_shard_group, False)
+                net.parallel_parameter_merge_net_dict[param_name] = allgather_net
             param_data = allgather_net(param_data)
     elif opt_shard_group:
-        allgather_net = get_allgather_cell(opt_shard_group, False)
+        if allgather_net is None:
+            allgather_net = get_allgather_cell(opt_shard_group, False)
+            net.parallel_parameter_merge_net_dict[param_name] = allgather_net
         param_data = allgather_net(param_data)
     return param_data
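The practical effect, as a hedged sketch — `net` is assumed to be a compiled semi_auto_parallel network with sliced parameters, and the file names are placeholders:

    from mindspore import save_checkpoint

    # First integrated save: _get_merged_param_data builds one AllGather merge
    # net per sliced parameter and caches it in net.parallel_parameter_merge_net_dict.
    save_checkpoint(net, "step_1.ckpt", integrated_save=True)
    # Later saves hit the cache, so no merge graph is compiled a second time.
    save_checkpoint(net, "step_2.ckpt", integrated_save=True)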
@@ -0,0 +1,105 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import numpy as np
+
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.nn import Cell, Momentum
+from mindspore.ops import operations as P
+from mindspore.train import Model
+from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
+from tests.dataset_mock import MindData
+
+
+class Dataset(MindData):
+    def __init__(self, predict, label, length=3):
+        super(Dataset, self).__init__(size=length)
+        self.predict = predict
+        self.label = label
+        self.index = 0
+        self.length = length
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.index >= self.length:
+            raise StopIteration
+        self.index += 1
+        return self.predict, self.label
+
+    def reset(self):
+        self.index = 0
+
+
+class Net(Cell):
+    def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, mask=0):
+        super().__init__()
+        self.mul = P.Mul().shard(strategy1)
+        self.strided_slice = P.StridedSlice(begin_mask=mask).shard(strategy2)
+        self.weight = Parameter(weight, "w1")
+        self.mul2 = P.Mul()
+        self.weight2 = Parameter(w2, "w2")
+        self.begin = begin
+        self.end = end
+        self.strides = strides
+
+    def construct(self, x, b):
+        out = self.strided_slice(
+            self.weight, self.begin, self.end, self.strides)
+        out = self.mul(x, out)
+        out = self.mul2(out, self.weight2)
+        return out
+
+
+_x = Tensor(np.ones([16, 64, 1]), dtype=ms.float32)
+_b = Tensor(np.ones([16, 64, 32]), dtype=ms.float32)
+_w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32)
+_w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
+
+
+def clean_all_ckpt_files(folder_path):
+    if os.path.exists(folder_path):
+        for file_name in os.listdir(folder_path):
+            if file_name.endswith('.ckpt') or file_name.endswith('.meta'):
+                os.remove(os.path.join(folder_path, file_name))
+
+
+def compile_net(net):
+    context.set_context(save_graphs=True)
+    learning_rate = 0.1
+    momentum = 0.9
+    epoch_size = 2
+    dataset = Dataset(_x, _b)
+    opt = Momentum(net.trainable_params(), learning_rate, momentum)
+    model = Model(net, optimizer=opt)
+    ckpt_config = CheckpointConfig(keep_checkpoint_max=1)
+    ckpt_path = "./parallel_ckpt"
+    ckpt_cb = ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
+    model.train(epoch_size, dataset, dataset_sink_mode=False, callbacks=[ckpt_cb])
+    assert len(model._train_network.parallel_parameter_merge_net_dict) == 4
+    clean_all_ckpt_files(ckpt_path)
+    context.reset_auto_parallel_context()
+
+
+def test_stridedslice_parameter():
+    context.set_auto_parallel_context(
+        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((1, 4, 1), (1, 4, 2))
+    strategy2 = ((1, 4, 2),)
+    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
+              strategy1, strategy2)
+    compile_net(net)