forked from mindspore-Ecosystem/mindspore
!20956 fix network issue
Merge pull request !20956 from JichenZhao/master
Commit: cb0792d32c
@@ -55,7 +55,7 @@

```python
# Distributed training example
sh scripts/run_distribute_train.sh rank_size /path/dataset
sh scripts/run_distribute_train.sh /path/dataset /path/rank_table

# Standalone training example
sh scripts/run_standalone_train.sh /path/dataset
@@ -108,7 +108,7 @@ The main parameters in train.py and val.py are as follows:

### Distributed training

```shell
sh scripts/run_distribute_train.sh rank_size /path/dataset
sh scripts/run_distribute_train.sh /path/dataset /path/rank_table
```

The above shell script will run distributed training in the background. You can check the results in the `device[X]/train.log` file.
@@ -154,7 +154,7 @@ epoch time: 1104929.793 ms, per step time: 97.162 ms

Also, please make sure that the evaluation dataset path passed in is "IJB_release/IJBB/" or "IJB_release/IJBC/".

```bash
sh scripts/run_eval_ijbc.sh /path/evalset /path/ckpt
sh scripts/run_eval_ijbc.sh /path/evalset /path/ckpt target_name
```

The above python command will run in the background. You can check the results in the eval.log file. The accuracy on the test dataset is as follows:
@@ -1,63 +0,0 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh RANK_SIZE DATA_PATH"
echo "For example: bash run.sh 8 path/dataset"
echo "It is better to use the absolute path."
echo "=============================================================================================================="

RANK_SIZE=$1
DATA_PATH=$2

EXEC_PATH=$(pwd)
echo "$EXEC_PATH"

test_dist_8pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
    export RANK_SIZE=8
}

test_dist_2pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
    echo "$RANK_TABLE_FILE"
    export RANK_SIZE=2
}

test_dist_${RANK_SIZE}pcs

for((i=0;i<RANK_SIZE;i++))
do
    rm -rf device$i
    mkdir device$i
    cp -r ./src/ ./device$i
    cp train.py ./device$i
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    python train.py \
    --data_url $DATA_PATH \
    --device_num RANK_SIZE \
    > train.log$i 2>&1 &
    cd ../
done
echo "finish"
cd ../
@@ -1,32 +0,0 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh EVAL_PATH CKPT_PATH"
echo "For example: bash run.sh path/evalset path/ckpt"
echo "It is better to use the absolute path."
echo "=============================================================================================================="

EVAL_PATH=$1
CKPT_PATH=$2

python val.py \
--ckpt_url "$CKPT_PATH" \
--device_id 1 \
--eval_url "$EVAL_PATH" \
--target lfw,cfp_fp,agedb_30,calfw,cplfw \
> eval.log 2>&1 &
@@ -1,34 +0,0 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh EVAL_PATH CKPT_PATH"
echo "For example: bash run.sh path/evalset path/ckpt"
echo "It is better to use the absolute path."
echo "=============================================================================================================="

EVAL_PATH=$1
CKPT_PATH=$2

python eval_ijbc.py \
--model-prefix "$CKPT_PATH" \
--image-path "$EVAL_PATH" \
--result-dir ms1mv2_arcface_r100 \
--batch-size 128 \
--job ms1mv2_arcface_r100 \
--target IJBC \
--network iresnet100 \
> eval.log 2>&1 &
@@ -1,30 +0,0 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh DATA_PATH"
echo "For example: bash run.sh path/MS1M"
echo "It is better to use the absolute path."
echo "=============================================================================================================="

# shellcheck disable=SC2034
DATA_PATH=$1

python train.py \
--data_url DATA_PATH \
--device_num 1 \
> train.log 2>&1 &
@@ -16,31 +16,24 @@

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh RANK_SIZE DATA_PATH"
echo "For example: bash run.sh 8 path/dataset"
echo "bash run.sh DATA_PATH RANK_TABLE"
echo "For example: bash run.sh /path/dataset /path/rank_table"
echo "It is better to use the absolute path."
echo "=============================================================================================================="

RANK_SIZE=$1
DATA_PATH=$2
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}
RANK_SIZE=8
DATA_PATH=$(get_real_path $1)
RANK_TABLE=$(get_real_path $2)

EXEC_PATH=$(pwd)
echo "$EXEC_PATH"

test_dist_8pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
    export RANK_SIZE=8
}

test_dist_2pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
    echo "$RANK_TABLE_FILE"
    export RANK_SIZE=2
}

test_dist_${RANK_SIZE}pcs
export RANK_TABLE_FILE=$RANK_TABLE

for((i=0;i<RANK_SIZE;i++))
do
@@ -15,8 +15,8 @@
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh EVAL_PATH CKPT_PATH"
echo "For example: bash run.sh path/evalset path/ckpt"
echo "bash run.sh EVAL_PATH CKPT_PATH TARGET"
echo "For example: bash run.sh path/evalset path/ckpt IJBC"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
@@ -29,6 +29,6 @@ python eval_ijbc.py \
--result-dir ms1mv2_arcface_r100 \
--batch-size 128 \
--job ms1mv2_arcface_r100 \
--target IJBC \
--target $3 \
--network iresnet100 \
> eval.log 2>&1 &
@@ -25,6 +25,6 @@ echo "=============================================================================================================="
DATA_PATH=$1

python train.py \
--data_url DATA_PATH \
--data_url $DATA_PATH \
--device_num 1 \
> train.log 2>&1 &
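The hunk above replaces the literal string `DATA_PATH` with the expanded shell variable `$DATA_PATH`, so train.py now receives the actual dataset path. A minimal sketch (not part of this PR) of how the receiving side could surface such a mistake early, assuming the usual argparse-based `--data_url` option; the existence check is hypothetical and added only for illustration:

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--data_url', type=str, required=True, help='Path to the training dataset.')
args = parser.parse_args()

# Without the leading '$' the shell passes the literal string "DATA_PATH",
# so an early existence check makes the mistake obvious at startup.
if not os.path.isdir(args.data_url):
    raise ValueError(f"--data_url does not point to an existing directory: {args.data_url}")
```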
@@ -61,8 +61,8 @@ if __name__ == '__main__':
    if args.distribute:
        if target == "Ascend":
            init()
            device_id = int(os.getenv('DEVICE_ID'))
            context.set_auto_parallel_context(device_id=device_id,
            device_num = int(os.getenv('RANK_SIZE'))
            context.set_auto_parallel_context(device_num=device_num,
                                              parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
        if target == "GPU":
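The fix passes the total device count (`RANK_SIZE`) rather than the local device id to the auto-parallel context. For context, a minimal sketch of the Ascend data-parallel setup this hunk converges on, assuming the `DEVICE_ID`, `RANK_ID` and `RANK_TABLE_FILE` environment variables exported by the launch scripts in this PR; aside from the standard `set_context` device binding, only APIs already visible in the diff are used:

```python
import os

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init

# Per-process device binding still uses the DEVICE_ID exported by the launch script.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    device_id=int(os.getenv('DEVICE_ID', '0')))

init()  # initialize HCCL from RANK_TABLE_FILE / RANK_ID in the environment

# The parallel context needs the total number of devices (RANK_SIZE),
# not the local device id, which is what this hunk fixes.
device_num = int(os.getenv('RANK_SIZE', '1'))
context.set_auto_parallel_context(device_num=device_num,
                                  parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)
```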
@@ -18,9 +18,9 @@ train
from __future__ import division

import os
import ast
import argparse
import numpy as np

from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.communication.management import init
@@ -66,7 +66,7 @@ def parse_args():
    parser.add_argument('--train_url', required=False, default=None, help='Location of training outputs.')
    parser.add_argument('--device_id', required=False, default=None, type=int, help='Location of training outputs.')
    parser.add_argument('--run_distribute', required=False, default=False, help='Location of training outputs.')
    parser.add_argument('--is_model_arts', required=False, default=False, help='Location of training outputs.')
    parser.add_argument('--is_model_arts', type=ast.literal_eval, default=False, help='Location of training outputs.')
    args = parser.parse_args()
    return args
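The switch to `type=ast.literal_eval` matters because argparse otherwise hands the option value over as a string, and any non-empty string, including "False", is truthy. A minimal sketch illustrating the difference; the `--naive_flag` name is hypothetical and used only for contrast:

```python
import argparse
import ast

parser = argparse.ArgumentParser()
# Plain string option: "--naive_flag False" arrives as the truthy string "False".
parser.add_argument('--naive_flag', default=False)
# ast.literal_eval parses "True"/"False" into real Python booleans.
parser.add_argument('--is_model_arts', type=ast.literal_eval, default=False)

args = parser.parse_args(['--naive_flag', 'False', '--is_model_arts', 'False'])
print(bool(args.naive_flag))   # True  -- the string "False" is truthy
print(args.is_model_arts)      # False -- parsed into a real boolean
```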
@@ -54,15 +54,21 @@ bash wmt14_en_fr.sh

```python
# Run the training example
python train.py > train.log 2>&1 &
cd ./scripts
bash run_standalone_train_ascend.sh [TRAIN_DATASET]

# Run the distributed training example
sh scripts/run_train.sh rank_table.json
cd ./scripts
bash run_distributed_train_ascend [RANK_TABLE] [TRAIN_DATASET]

# Run the evaluation example
python eval.py > eval.log 2>&1 &
or
sh run_eval.sh
cd ./scripts
bash run_standalone_eval_ascend.sh \
    seq2seq/dataset_menu/newstest2014.en.mindrecord \
    seq2seq/scripts/device0/text_translation/ckpt_0/seq2seq-8_3437.ckpt \
    seq2seq/dataset_menu/vocab.bpe.32000 \
    seq2seq/dataset_menu/bpe.32000 \
    seq2seq/dataset_menu/newstest2014.fr
```

For distributed training, an hccl configuration file in JSON format needs to be created in advance.
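For reference, the hccl configuration file is a rank table describing the Ascend devices that take part in training. A minimal sketch of generating one, assuming a single-server, 8-device setup; the server id and device IPs below are placeholders, and the exact schema should be checked against the MindSpore distributed training documentation:

```python
import json

# Placeholder values; replace with the real host IP and the NPU IPs
# reported by the Ascend toolkit on the training server.
rank_table = {
    "version": "1.0",
    "server_count": "1",
    "server_list": [
        {
            "server_id": "10.0.0.1",
            "device": [
                {"device_id": str(i), "device_ip": f"192.168.100.{i + 1}", "rank_id": str(i)}
                for i in range(8)
            ],
            "host_nic_ip": "reserve",
        }
    ],
    "status": "completed",
}

with open("rank_table_8pcs.json", "w") as f:
    json.dump(rank_table, f, indent=4)
```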
@@ -159,7 +165,8 @@ bash wmt14_en_fr.sh

- Running in the Ascend processor environment

```bash
bash scripts/run_standalone_train_ascend.sh
cd ./scripts
bash run_standalone_train_ascend.sh [TRAIN_DATASET]
```

The above command will run in the background. You can check the results in the scripts/train/log_seq2seq_network.log file. The loss values are saved in scripts/train/loss.log.
@@ -171,7 +178,8 @@ bash wmt14_en_fr.sh

- Running in the Ascend processor environment

```bash
bash scripts/run_distributed_train_ascend rank_table.json
cd ./scripts
bash run_distributed_train_ascend [RANK_TABLE] [TRAIN_DATASET]
```

The above shell script will run distributed training in the background. You can check the results in the scripts/device[X]/log_seq2seq_network.log file. The loss values are saved in scripts/device[X]/loss.log.
@@ -185,7 +193,8 @@ bash wmt14_en_fr.sh

- For evaluation in the Ascend environment, an example script is shown below

```bash
sh run_standalone_eval_ascend.sh \
cd ./scripts
bash run_standalone_eval_ascend.sh \
    seq2seq/dataset_menu/newstest2014.en.mindrecord \
    seq2seq/scripts/device0/text_translation/ckpt_0/seq2seq-8_3437.ckpt \
    seq2seq/dataset_menu/vocab.bpe.32000 \
@@ -208,7 +217,7 @@ bash wmt14_en_fr.sh

| Parameters | Ascend |
| ------------- | ------------------------------------------------------------ |
| Model version | Inception V1 |
| Model version | Seq2Seq |
| Resources | Ascend 910; CPU 2.60GHz, 56 cores; memory: 314G |
| Upload date | 2021-03-29 |
| MindSpore version | 1.1.1 |
@@ -227,7 +236,7 @@ bash wmt14_en_fr.sh

| Parameters | Ascend |
| ------------- | -------------- |
| Model version | Inception V1 |
| Model version | Seq2Seq |
| Resources | Ascend 910 |
| Upload date | 2021-03-29 |
| MindSpore version | 1.1.1 |
@@ -48,6 +48,7 @@ do
    export RANK_ID=$i
    export DEVICE_ID=$i
    python ../../train.py \
    --is_modelarts=False \
    --config=${current_exec_path}/device${i}/config/config.json \
    --pre_train_dataset=$PRE_TRAIN_DATASET > log_seq2seq_network${i}.log 2>&1 &
    cd ${current_exec_path} || exit
@@ -41,6 +41,7 @@ cd ./train || exit
echo "start for training"
env > env.log
python train.py \
--is_modelarts=False \
--config=${current_exec_path}/train/config/config.json \
--pre_train_dataset=$PRE_TRAIN_DATASET > log_seq2seq_network.log 2>&1 &
cd ..
@@ -81,7 +81,7 @@ class PredLogProbs(nn.Cell):
        Tensor, log softmax output.
        """

    def __init__(self, config):
    def __init__(self):
        super(PredLogProbs, self).__init__()
        self.reshape = P.Reshape()
        self.log_softmax = nn.LogSoftmax(axis=-1)
@@ -172,7 +172,7 @@ class Seq2seqTraining(nn.Cell):
    def __init__(self, config, is_training, use_one_hot_embeddings):
        super(Seq2seqTraining, self).__init__()
        self.seq2seq = Seq2seqModel(config, is_training, use_one_hot_embeddings)
        self.projection = PredLogProbs(config)
        self.projection = PredLogProbs()

    def construct(self, source_ids, source_mask, target_ids):
        """
@@ -285,7 +285,6 @@ class Seq2seqTrainOneStepWithLossScaleCell(nn.Cell):
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
@@ -14,9 +14,9 @@
# ============================================================================
"""Train api."""
import os
import ast
import argparse
import numpy as np
import moxing as mox

import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
@@ -25,7 +25,7 @@ from mindspore.nn.optim import Lamb
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train.callback import LossMonitor, SummaryCollector
from mindspore.train.callback import LossMonitor, SummaryCollector
from mindspore import context, Parameter
from mindspore.context import ParallelMode
from mindspore.communication import management as MultiAscend
@@ -42,28 +42,28 @@ from src.utils.optimizer import Adam

parser = argparse.ArgumentParser(description='Seq2seq train entry point.')

is_modelarts = False

if is_modelarts:
    parser.add_argument("--config", type=str, required=True, help="model config json file path.")
    parser.add_argument("--data_url", type=str, required=True, help="pre-train dataset address.")
    parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')

parser.add_argument("--is_modelarts", type=ast.literal_eval, default=False, help="model config json file path.")
parser.add_argument("--data_url", type=str, default=None, help="pre-train dataset address.")
parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")

args = parser.parse_args()
if args.is_modelarts:
    import moxing as mox
context.set_context(
    mode=context.GRAPH_MODE,
    save_graphs=True,
    device_target="Ascend",
    reserve_class_name_in_scope=True)


def get_config(config):
    config = Seq2seqConfig.from_json_file(config)
    config.compute_type = mstype.float16
    config.dtype = mstype.float32
    return config


def _train(model, config: Seq2seqConfig,
           pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None,
           callbacks: list = None):
@@ -333,17 +333,16 @@ def _check_args(config):
    if not isinstance(config, str):
        raise ValueError("`config` must be type of str.")


if __name__ == '__main__':
    _rank_size = os.getenv('RANK_SIZE')
    args, _ = parser.parse_known_args()

    if is_modelarts:
        mox.file.copy_parallel(src_url=args.data_url, dst_url='/cache/dataset_menu/')
        _config.pre_train_dataset = '/cache/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord'
        _config.ckpt_path = '/cache/train_output/'

    _check_args(args.config)
    _config = get_config(args.config)
    if args.is_modelarts:
        mox.file.copy_parallel(src_url=args.data_url, dst_url='/cache/dataset_menu/')
        _config.pre_train_dataset = '/cache/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord'
        _config.ckpt_path = '/cache/train_output/'
    _config.pre_train_dataset = args.pre_train_dataset

    set_seed(_config.random_seed)
@@ -353,5 +352,5 @@ if __name__ == '__main__':
    else:
        train_single(_config)

    if is_modelarts:
    if args.is_modelarts:
        mox.file.copy_parallel(src_url='/cache/train_output/', dst_url=args.train_url)