!9401 fix broadcast bug

From: @jinyaohui
mindspore-ci-bot 2020-12-03 09:58:28 +08:00 committed by Gitee
commit 23856d6117
8 changed files with 185 additions and 18 deletions


@@ -337,7 +337,6 @@ class _Executor:
        self.is_init = False
        self._executor = Executor_.get_instance()
        self.compile_cache = {}
        self.phase_prefix = ""

    def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
                     input_indexs, phase='dataset'):
@@ -383,7 +382,12 @@ class _Executor:
        """Build broadcast graph."""
        from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
        _broadcast_net = _BroadCastCell(broadcast_params_dict.values())
        if not broadcast_params_dict:
            broadcast_params_dict = {}
        broadcast_params = []
        for param in broadcast_params_dict.values():
            broadcast_params.append(Tensor(param.asnumpy()))
        _broadcast_net = _BroadCastCell(broadcast_params)
        _broadcast_net.phase = broadcast_phase
        broadcasted_params = _broadcast_net()
        for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
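The hunk above copies each parameter's value into a fresh Tensor before handing the list to _BroadCastCell, instead of passing the dict's Parameter values directly. A minimal standalone sketch of that pattern, assuming only the Tensor and Parameter.asnumpy() APIs visible in the hunk (the helper name is hypothetical):

from mindspore import Tensor

def _params_to_broadcast_inputs(broadcast_params_dict):
    # Snapshot every Parameter as a plain Tensor so the broadcast cell
    # receives concrete tensors rather than live Parameter objects.
    broadcast_params_dict = broadcast_params_dict or {}
    return [Tensor(param.asnumpy()) for param in broadcast_params_dict.values()]
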
@@ -440,11 +444,11 @@ class _Executor:
        if not hasattr(obj, "inputs_to_attr"):
            dic = dict(zip(args_names, args_list))
            key = generate_key(phase, dic)
            self.phase_prefix = str(key[1])
            obj.phase_prefix = str(key[1])
            if 'export' in phase:
                phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
                phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
            else:
                phase = self.phase_prefix + phase + '.' + str(obj.create_time)
                phase = obj.phase_prefix + phase + '.' + str(obj.create_time)

        if phase in self.compile_cache.keys():
            logger.debug("%r graph has existed.", phase)
@@ -518,9 +522,8 @@ class _Executor:
            for param_name, param in obj.parameters_broadcast_dict().items():
                if param_name not in auto_split_param_names:
                    broadcast_params_dict[param_name] = param
            broadcast_phase = "_broadcast_subgraph" + "." + str(obj.create_time)
            broadcast_phase = "_broadcast_subgraph"
            self._build_broadcast_graph(broadcast_params_dict, broadcast_phase)
            self.compile_cache[phase] = broadcast_phase

        return phase, True
@@ -529,15 +532,15 @@ class _Executor:
        return self._executor.updata_param_node_default_input(phase, new_param)

    def _get_shard_strategy(self, obj):
        real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
        real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
        return self._executor.get_strategy(real_phase)

    def _get_num_parallel_ops(self, obj):
        real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
        real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
        return self._executor.get_num_parallel_ops(real_phase)

    def _get_allreduce_fusion(self, obj):
        real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
        real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
        return self._executor.get_allreduce_fusion(real_phase)

    def has_compiled(self, phase='predict'):
def has_compiled(self, phase='predict'):
@@ -581,7 +584,7 @@ class _Executor:
        if phase == 'save':
            return self._executor((), phase + '.' + str(obj.create_time))

        phase_real = self.phase_prefix + phase + '.' + str(obj.create_time)
        phase_real = obj.phase_prefix + phase + '.' + str(obj.create_time)
        if self.has_compiled(phase_real):
            return self._exec_pip(obj, *args, phase=phase_real)
        raise KeyError('{} graph is not exist.'.format(phase_real))
@@ -589,10 +592,10 @@ class _Executor:
    def del_net_res(self, net_id):
        self._executor.del_net_res(net_id)

    def _get_func_graph_proto(self, exec_id, ir_type="onnx_ir", use_prefix=False):
    def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False):
        """Get graph proto from pipeline."""
        if use_prefix:
            exec_id = self.phase_prefix + exec_id
            exec_id = obj.phase_prefix + exec_id
        if self._executor.has_compiled(exec_id) is False:
            return None
        return self._executor.get_func_graph_proto(exec_id, ir_type)
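With this change every Cell instance keeps its own phase_prefix, so the executor no longer shares one prefix across all compiled networks and the real phase name is rebuilt from the object itself. A hedged illustration of how that key is assembled (the stand-in class and the values are hypothetical):

class _FakeCell:
    # stand-in for the nn.Cell attributes the executor reads
    phase_prefix = "0"
    phase = "train"
    create_time = 1606953508000000000

obj = _FakeCell()
real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
# -> '0train.1606953508000000000'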


@@ -570,8 +570,8 @@ def set_context(**kwargs):
        >>> context.set_context(reserve_class_name_in_scope=True)
        >>> context.set_context(variable_memory_max_size="6GB")
        >>> context.set_context(mode=context.GRAPH_MODE,
        >>>                     device_target="Ascend",device_id=0, save_graphs=True,
        >>>                     save_graphs_path="/mindspore")
        ...                     device_target="Ascend",device_id=0, save_graphs=True,
        ...                     save_graphs_path="/mindspore")
        >>> context.set_context(enable_profiling=True, profiling_options="training_trace")
        >>> context.set_context(max_device_memory="3.5GB")
        >>> context.set_context(print_file_path="print.pb")


@@ -87,6 +87,7 @@ class Cell(Cell_):
        self._phase = 'train'
        self._parameter_layout_dict = {}
        self._create_time = int(time.time() * 1e9)
        self.phase_prefix = ""
        init_backend()

        # call gc to release GE session resources used by non-used cell objects
@@ -237,7 +238,7 @@ class Cell(Cell_):
    def get_func_graph_proto(self):
        """Return graph binary proto."""
        return _executor._get_func_graph_proto(self.phase + "." + str(self.create_time), "anf_ir", True)
        return _executor._get_func_graph_proto(self, self.phase + "." + str(self.create_time), "anf_ir", True)

    def __getattr__(self, name):
        if '_params' in self.__dict__:


@@ -556,7 +556,7 @@ def _export(net, file_name, file_format, *inputs):
    elif file_format == 'ONNX': # file_format is 'ONNX'
        phase_name = 'export.onnx'
        graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
        onnx_stream = _executor._get_func_graph_proto(graph_id)
        onnx_stream = _executor._get_func_graph_proto(net, graph_id)
        file_name += ".onnx"
        with open(file_name, 'wb') as f:
            os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
@@ -564,7 +564,7 @@ def _export(net, file_name, file_format, *inputs):
    elif file_format == 'MINDIR': # file_format is 'MINDIR'
        phase_name = 'export.mindir'
        graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
        onnx_stream = _executor._get_func_graph_proto(graph_id, 'mind_ir')
        onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
        file_name += ".mindir"
        with open(file_name, 'wb') as f:
            os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)

tests/st/broadcast/env.sh (new file)

@@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
LOCAL_HIAI=/usr/local/HiAI
export TBE_IMPL_PATH=${LOCAL_HIAI}/runtime/ops/op_impl/built-in/ai_core/tbe/impl/
export LD_LIBRARY_PATH=${LOCAL_HIAI}/runtime/lib64/:${LD_LIBRARY_PATH}
export PATH=${LOCAL_HIAI}/runtime/ccec_compiler/bin/:${PATH}
export PYTHONPATH=${LOCAL_HIAI}/runtime/ops/op_impl/built-in/ai_core/tbe/:${PYTHONPATH}
export DEVICE_MEMORY_CAPACITY=1073741824000
export NOT_FULLY_USE_DEVICES=off


@@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
import mindspore.communication.management as distributedTool
import mindspore.nn as nn
from mindspore import context
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.callback import LossMonitor, TimeMonitor
from model_zoo.official.cv.lenet.src.dataset import create_dataset
from model_zoo.official.cv.lenet.src.lenet import LeNet5
np.set_printoptions(threshold=np.inf)
device_num = 2
device_id = int(os.getenv('DEVICE_ID'))
rank_id = 0


def setup_module():
    global device_num
    global rank_id
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(device_id=device_id)
    distributedTool.init()
    rank_id = distributedTool.get_rank()
    device_num = distributedTool.get_group_size()
    context.set_auto_parallel_context(device_num=device_num, global_rank=device_id, parameter_broadcast=True)


def teardown_module():
    distributedTool.release()


def test_all_trains():
    ds_train = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "train"), 32, 1)
    network = LeNet5(10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
    print("============== Starting Training ==============")
    model.train(1, ds_train, callbacks=[time_cb, LossMonitor()])


@@ -0,0 +1,53 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
set -e
BASE_PATH=$(
    cd "$(dirname $0)"
    pwd
)
CONFIG_PATH=/home/workspace/mindspore_config
export DEVICE_NUM=8
export RANK_SIZE=$DEVICE_NUM
source ${BASE_PATH}/env.sh
unset SLOG_PRINT_TO_STDOUT
export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json
process_pid=()
for ((i = 0; i < $DEVICE_NUM; i++)); do
    rm -rf ${BASE_PATH}/lenet_broadcast${i}
    mkdir ${BASE_PATH}/lenet_broadcast${i}
    cp -r ${BASE_PATH}/lenet_broadcast_auto_parallel.py ${BASE_PATH}/lenet_broadcast${i}/
    cd ${BASE_PATH}/lenet_broadcast${i}
    export RANK_ID=${i}
    export DEVICE_ID=${i}
    echo "start training for device $i"
    env >env$i.log
    pytest -s -v lenet_broadcast_auto_parallel.py >test_lenet_auto_parallel_broadcast_8p_log$i.log 2>&1 &
    process_pid[${i}]=$(echo $!)
done

for ((i = 0; i < ${DEVICE_NUM}; i++)); do
    wait ${process_pid[i]}
    status=$(echo $?)
    if [ "${status}" != "0" ]; then
        echo "[ERROR] test_broadcast_auto_parallel failed. status: ${status}"
        exit 1
    else
        echo "[INFO] test_broadcast_auto_parallel success."
    fi
done
exit 0


@@ -0,0 +1,27 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os

import pytest


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_broadcast_auto_parallel():
    sh_path = os.path.split(os.path.realpath(__file__))[0]
    ret = os.system(f"sh {sh_path}/run_broadcast_auto_parallel.sh")
    assert ret == 0