!20956 fix network issue
Merge pull request !20956 from JichenZhao/master
This commit is contained in:
commit
cb0792d32c
|
@ -55,7 +55,7 @@
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# 分布式训练运行示例
|
# 分布式训练运行示例
|
||||||
sh scripts/run_distribute_train.sh rank_size /path/dataset
|
sh scripts/run_distribute_train.sh /path/dataset /path/rank_table
|
||||||
|
|
||||||
# 单机训练运行示例
|
# 单机训练运行示例
|
||||||
sh scripts/run_standalone_train.sh /path/dataset
|
sh scripts/run_standalone_train.sh /path/dataset
|
||||||
|
@ -108,7 +108,7 @@ train.py和val.py中主要参数如下:
|
||||||
### 分布式训练
|
### 分布式训练
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
sh scripts/run_distribute_train.sh rank_size /path/dataset
|
sh scripts/run_distribute_train.sh /path/dataset /path/rank_table
|
||||||
```
|
```
|
||||||
|
|
||||||
上述shell脚本将在后台运行分布训练。可以通过`device[X]/train.log`文件查看结果。
|
上述shell脚本将在后台运行分布训练。可以通过`device[X]/train.log`文件查看结果。
|
||||||
|
@ -154,7 +154,7 @@ epoch time: 1104929.793 ms, per step time: 97.162 ms
|
||||||
同时,情确保传入的评估数据集路径为“IJB_release/IJBB/”或“IJB_release/IJBC/”。
|
同时,情确保传入的评估数据集路径为“IJB_release/IJBB/”或“IJB_release/IJBC/”。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sh scripts/run_eval_ijbc.sh /path/evalset /path/ckpt
|
sh scripts/run_eval_ijbc.sh /path/evalset /path/ckpt target_name
|
||||||
```
|
```
|
||||||
|
|
||||||
上述python命令将在后台运行,您可以通过eval.log文件查看结果。测试数据集的准确性如下:
|
上述python命令将在后台运行,您可以通过eval.log文件查看结果。测试数据集的准确性如下:
|
|
@ -1,63 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
echo "=============================================================================================================="
|
|
||||||
echo "Please run the script as: "
|
|
||||||
echo "bash run.sh RANK_SIZE DATA_PATH"
|
|
||||||
echo "For example: bash run.sh 8 path/dataset"
|
|
||||||
echo "It is better to use the absolute path."
|
|
||||||
echo "=============================================================================================================="
|
|
||||||
|
|
||||||
RANK_SIZE=$1
|
|
||||||
DATA_PATH=$2
|
|
||||||
|
|
||||||
EXEC_PATH=$(pwd)
|
|
||||||
echo "$EXEC_PATH"
|
|
||||||
|
|
||||||
test_dist_8pcs()
|
|
||||||
{
|
|
||||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
|
|
||||||
export RANK_SIZE=8
|
|
||||||
}
|
|
||||||
|
|
||||||
test_dist_2pcs()
|
|
||||||
{
|
|
||||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
|
|
||||||
echo "$RANK_TABLE_FILE"
|
|
||||||
export RANK_SIZE=2
|
|
||||||
}
|
|
||||||
|
|
||||||
test_dist_${RANK_SIZE}pcs
|
|
||||||
|
|
||||||
for((i=0;i<RANK_SIZE;i++))
|
|
||||||
do
|
|
||||||
rm -rf device$i
|
|
||||||
mkdir device$i
|
|
||||||
cp -r ./src/ ./device$i
|
|
||||||
cp train.py ./device$i
|
|
||||||
cd ./device$i
|
|
||||||
export DEVICE_ID=$i
|
|
||||||
export RANK_ID=$i
|
|
||||||
echo "start training for device $i"
|
|
||||||
env > env$i.log
|
|
||||||
python train.py \
|
|
||||||
--data_url $DATA_PATH \
|
|
||||||
--device_num RANK_SIZE \
|
|
||||||
> train.log$i 2>&1 &
|
|
||||||
cd ../
|
|
||||||
done
|
|
||||||
echo "finish"
|
|
||||||
cd ../
|
|
|
@ -1,32 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
echo "=============================================================================================================="
|
|
||||||
echo "Please run the script as: "
|
|
||||||
echo "bash run.sh EVAL_PATH CKPT_PATH"
|
|
||||||
echo "For example: bash run.sh path/evalset path/ckpt"
|
|
||||||
echo "It is better to use the absolute path."
|
|
||||||
echo "=============================================================================================================="
|
|
||||||
|
|
||||||
EVAL_PATH=$1
|
|
||||||
CKPT_PATH=$2
|
|
||||||
|
|
||||||
python val.py \
|
|
||||||
--ckpt_url "$CKPT_PATH" \
|
|
||||||
--device_id 1 \
|
|
||||||
--eval_url "$EVAL_PATH" \
|
|
||||||
--target lfw,cfp_fp,agedb_30,calfw,cplfw \
|
|
||||||
> eval.log 2>&1 &
|
|
|
@ -1,34 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
echo "=============================================================================================================="
|
|
||||||
echo "Please run the script as: "
|
|
||||||
echo "bash run.sh EVAL_PATH CKPT_PATH"
|
|
||||||
echo "For example: bash run.sh path/evalset path/ckpt"
|
|
||||||
echo "It is better to use the absolute path."
|
|
||||||
echo "=============================================================================================================="
|
|
||||||
|
|
||||||
EVAL_PATH=$1
|
|
||||||
CKPT_PATH=$2
|
|
||||||
|
|
||||||
python eval_ijbc.py \
|
|
||||||
--model-prefix "$CKPT_PATH" \
|
|
||||||
--image-path "$EVAL_PATH" \
|
|
||||||
--result-dir ms1mv2_arcface_r100 \
|
|
||||||
--batch-size 128 \
|
|
||||||
--job ms1mv2_arcface_r100 \
|
|
||||||
--target IJBC \
|
|
||||||
--network iresnet100 \
|
|
||||||
> eval.log 2>&1 &
|
|
|
@ -1,30 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
echo "=============================================================================================================="
|
|
||||||
echo "Please run the script as: "
|
|
||||||
echo "bash run.sh DATA_PATH"
|
|
||||||
echo "For example: bash run.sh path/MS1M"
|
|
||||||
echo "It is better to use the absolute path."
|
|
||||||
echo "=============================================================================================================="
|
|
||||||
|
|
||||||
# shellcheck disable=SC2034
|
|
||||||
DATA_PATH=$1
|
|
||||||
|
|
||||||
python train.py \
|
|
||||||
--data_url DATA_PATH \
|
|
||||||
--device_num 1 \
|
|
||||||
> train.log 2>&1 &
|
|
|
@ -16,31 +16,24 @@
|
||||||
|
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
echo "Please run the script as: "
|
echo "Please run the script as: "
|
||||||
echo "bash run.sh RANK_SIZE DATA_PATH"
|
echo "bash run.sh DATA_PATH RANK_TABLE"
|
||||||
echo "For example: bash run.sh 8 path/dataset"
|
echo "For example: bash run.sh /path/dataset /path/rank_table"
|
||||||
echo "It is better to use the absolute path."
|
echo "It is better to use the absolute path."
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
|
get_real_path(){
|
||||||
RANK_SIZE=$1
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
DATA_PATH=$2
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
RANK_SIZE=8
|
||||||
|
DATA_PATH=$(get_real_path $1)
|
||||||
|
RANK_TABLE=$(get_real_path $2)
|
||||||
|
|
||||||
EXEC_PATH=$(pwd)
|
EXEC_PATH=$(pwd)
|
||||||
echo "$EXEC_PATH"
|
echo "$EXEC_PATH"
|
||||||
|
export RANK_TABLE_FILE=$RANK_TABLE
|
||||||
test_dist_8pcs()
|
|
||||||
{
|
|
||||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
|
|
||||||
export RANK_SIZE=8
|
|
||||||
}
|
|
||||||
|
|
||||||
test_dist_2pcs()
|
|
||||||
{
|
|
||||||
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
|
|
||||||
echo "$RANK_TABLE_FILE"
|
|
||||||
export RANK_SIZE=2
|
|
||||||
}
|
|
||||||
|
|
||||||
test_dist_${RANK_SIZE}pcs
|
|
||||||
|
|
||||||
for((i=0;i<RANK_SIZE;i++))
|
for((i=0;i<RANK_SIZE;i++))
|
||||||
do
|
do
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
echo "Please run the script as: "
|
echo "Please run the script as: "
|
||||||
echo "bash run.sh EVAL_PATH CKPT_PATH"
|
echo "bash run.sh EVAL_PATH CKPT_PATH TARGET"
|
||||||
echo "For example: bash run.sh path/evalset path/ckpt"
|
echo "For example: bash run.sh path/evalset path/ckpt IJBC"
|
||||||
echo "It is better to use the absolute path."
|
echo "It is better to use the absolute path."
|
||||||
echo "=============================================================================================================="
|
echo "=============================================================================================================="
|
||||||
|
|
||||||
|
@ -29,6 +29,6 @@ python eval_ijbc.py \
|
||||||
--result-dir ms1mv2_arcface_r100 \
|
--result-dir ms1mv2_arcface_r100 \
|
||||||
--batch-size 128 \
|
--batch-size 128 \
|
||||||
--job ms1mv2_arcface_r100 \
|
--job ms1mv2_arcface_r100 \
|
||||||
--target IJBC \
|
--target $3 \
|
||||||
--network iresnet100 \
|
--network iresnet100 \
|
||||||
> eval.log 2>&1 &
|
> eval.log 2>&1 &
|
||||||
|
|
|
@ -25,6 +25,6 @@ echo "==========================================================================
|
||||||
DATA_PATH=$1
|
DATA_PATH=$1
|
||||||
|
|
||||||
python train.py \
|
python train.py \
|
||||||
--data_url DATA_PATH \
|
--data_url $DATA_PATH \
|
||||||
--device_num 1 \
|
--device_num 1 \
|
||||||
> train.log 2>&1 &
|
> train.log 2>&1 &
|
|
@ -61,8 +61,8 @@ if __name__ == '__main__':
|
||||||
if args.distribute:
|
if args.distribute:
|
||||||
if target == "Ascend":
|
if target == "Ascend":
|
||||||
init()
|
init()
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
device_num = int(os.getenv('RANK_SIZE'))
|
||||||
context.set_auto_parallel_context(device_id=device_id,
|
context.set_auto_parallel_context(device_num=device_num,
|
||||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
gradients_mean=True)
|
gradients_mean=True)
|
||||||
if target == "GPU":
|
if target == "GPU":
|
||||||
|
|
|
@ -18,9 +18,9 @@ train
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import ast
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import context, Tensor
|
from mindspore import context, Tensor
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.communication.management import init
|
from mindspore.communication.management import init
|
||||||
|
@ -66,7 +66,7 @@ def parse_args():
|
||||||
parser.add_argument('--train_url', required=False, default=None, help='Location of training outputs.')
|
parser.add_argument('--train_url', required=False, default=None, help='Location of training outputs.')
|
||||||
parser.add_argument('--device_id', required=False, default=None, type=int, help='Location of training outputs.')
|
parser.add_argument('--device_id', required=False, default=None, type=int, help='Location of training outputs.')
|
||||||
parser.add_argument('--run_distribute', required=False, default=False, help='Location of training outputs.')
|
parser.add_argument('--run_distribute', required=False, default=False, help='Location of training outputs.')
|
||||||
parser.add_argument('--is_model_arts', required=False, default=False, help='Location of training outputs.')
|
parser.add_argument('--is_model_arts', type=ast.literal_eval, default=False, help='Location of training outputs.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
|
@ -54,15 +54,21 @@ bash wmt14_en_fr.sh
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# 运行训练示例
|
# 运行训练示例
|
||||||
python train.py > train.log 2>&1 &
|
cd ./scripts
|
||||||
|
bash run_standalone_train_ascend.sh [TRAIN_DATASET]
|
||||||
|
|
||||||
# 运行分布式训练示例
|
# 运行分布式训练示例
|
||||||
sh scripts/run_train.sh rank_table.json
|
cd ./scripts
|
||||||
|
bash run_distributed_train_ascend [RANK_TABLE] [TRAIN_DATASET]
|
||||||
|
|
||||||
# 运行评估示例
|
# 运行评估示例
|
||||||
python eval.py > eval.log 2>&1 &
|
cd ./scripts
|
||||||
或
|
bash run_standalone_eval_ascend.sh \
|
||||||
sh run_eval.sh
|
seq2seq/dataset_menu/newstest2014.en.mindrecord \
|
||||||
|
seq2seq/scripts/device0/text_translation/ckpt_0/seq2seq-8_3437.ckpt \
|
||||||
|
seq2seq/dataset_menu/vocab.bpe.32000 \
|
||||||
|
seq2seq/dataset_menu/bpe.32000 \
|
||||||
|
seq2seq/dataset_menu/newstest2014.fr
|
||||||
```
|
```
|
||||||
|
|
||||||
对于分布式训练,需要提前创建JSON格式的hccl配置文件。
|
对于分布式训练,需要提前创建JSON格式的hccl配置文件。
|
||||||
|
@ -159,7 +165,8 @@ bash wmt14_en_fr.sh
|
||||||
- Ascend处理器环境运行
|
- Ascend处理器环境运行
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bash scripts/run_standalone_train_ascend.sh
|
cd ./scripts
|
||||||
|
bash run_standalone_train_ascend.sh [TRAIN_DATASET]
|
||||||
```
|
```
|
||||||
|
|
||||||
上述python命令将在后台运行,您可以通过scripts/train/log_seq2seq_network.log文件查看结果。loss值保存在scripts/train/loss.log
|
上述python命令将在后台运行,您可以通过scripts/train/log_seq2seq_network.log文件查看结果。loss值保存在scripts/train/loss.log
|
||||||
|
@ -171,7 +178,8 @@ bash wmt14_en_fr.sh
|
||||||
- Ascend处理器环境运行
|
- Ascend处理器环境运行
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
bash scripts/run_distributed_train_ascend rank_table.json
|
cd ./scripts
|
||||||
|
bash run_distributed_train_ascend [RANK_TABLE] [TRAIN_DATASET]
|
||||||
```
|
```
|
||||||
|
|
||||||
上述shell脚本将在后台运行分布训练。您可以通过scripts/device[X]/log_seq2seq_network.log文件查看结果。loss值保存在scripts/device[X]/loss.log
|
上述shell脚本将在后台运行分布训练。您可以通过scripts/device[X]/log_seq2seq_network.log文件查看结果。loss值保存在scripts/device[X]/loss.log
|
||||||
|
@ -185,7 +193,8 @@ bash wmt14_en_fr.sh
|
||||||
- 在Ascend环境运行时评估,脚本示例如下
|
- 在Ascend环境运行时评估,脚本示例如下
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sh run_standalone_eval_ascend.sh \
|
cd ./scripts
|
||||||
|
bash run_standalone_eval_ascend.sh \
|
||||||
seq2seq/dataset_menu/newstest2014.en.mindrecord \
|
seq2seq/dataset_menu/newstest2014.en.mindrecord \
|
||||||
seq2seq/scripts/device0/text_translation/ckpt_0/seq2seq-8_3437.ckpt \
|
seq2seq/scripts/device0/text_translation/ckpt_0/seq2seq-8_3437.ckpt \
|
||||||
seq2seq/dataset_menu/vocab.bpe.32000 \
|
seq2seq/dataset_menu/vocab.bpe.32000 \
|
||||||
|
@ -208,7 +217,7 @@ bash wmt14_en_fr.sh
|
||||||
|
|
||||||
| 参数 | Ascend |
|
| 参数 | Ascend |
|
||||||
| ------------- | ------------------------------------------------------------ |
|
| ------------- | ------------------------------------------------------------ |
|
||||||
| 模型版本 | Inception V1 |
|
| 模型版本 | Seq2Seq |
|
||||||
| 资源 | Ascend 910, CPU 2.60GHz, 56核, 内存:314G |
|
| 资源 | Ascend 910, CPU 2.60GHz, 56核, 内存:314G |
|
||||||
| 上传日期 | 2021-3-29 |
|
| 上传日期 | 2021-3-29 |
|
||||||
| MindSpore版本 | 1.1.1 |
|
| MindSpore版本 | 1.1.1 |
|
||||||
|
@ -227,7 +236,7 @@ bash wmt14_en_fr.sh
|
||||||
|
|
||||||
| 参数 | Ascend |
|
| 参数 | Ascend |
|
||||||
| ------------- | -------------- |
|
| ------------- | -------------- |
|
||||||
| 模型版本 | Inception V1 |
|
| 模型版本 | Seq2Seq |
|
||||||
| 资源 | Ascend 910 |
|
| 资源 | Ascend 910 |
|
||||||
| 上传日期 | 2021-03-29 |
|
| 上传日期 | 2021-03-29 |
|
||||||
| MindSpore版本 | 1.1.1 |
|
| MindSpore版本 | 1.1.1 |
|
|
@ -48,6 +48,7 @@ do
|
||||||
export RANK_ID=$i
|
export RANK_ID=$i
|
||||||
export DEVICE_ID=$i
|
export DEVICE_ID=$i
|
||||||
python ../../train.py \
|
python ../../train.py \
|
||||||
|
--is_modelarts=False \
|
||||||
--config=${current_exec_path}/device${i}/config/config.json \
|
--config=${current_exec_path}/device${i}/config/config.json \
|
||||||
--pre_train_dataset=$PRE_TRAIN_DATASET > log_seq2seq_network${i}.log 2>&1 &
|
--pre_train_dataset=$PRE_TRAIN_DATASET > log_seq2seq_network${i}.log 2>&1 &
|
||||||
cd ${current_exec_path} || exit
|
cd ${current_exec_path} || exit
|
||||||
|
|
|
@ -41,6 +41,7 @@ cd ./train || exit
|
||||||
echo "start for training"
|
echo "start for training"
|
||||||
env > env.log
|
env > env.log
|
||||||
python train.py \
|
python train.py \
|
||||||
|
--is_modelarts=False \
|
||||||
--config=${current_exec_path}/train/config/config.json \
|
--config=${current_exec_path}/train/config/config.json \
|
||||||
--pre_train_dataset=$PRE_TRAIN_DATASET > log_seq2seq_network.log 2>&1 &
|
--pre_train_dataset=$PRE_TRAIN_DATASET > log_seq2seq_network.log 2>&1 &
|
||||||
cd ..
|
cd ..
|
||||||
|
|
|
@ -81,7 +81,7 @@ class PredLogProbs(nn.Cell):
|
||||||
Tensor, log softmax output.
|
Tensor, log softmax output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self):
|
||||||
super(PredLogProbs, self).__init__()
|
super(PredLogProbs, self).__init__()
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
self.log_softmax = nn.LogSoftmax(axis=-1)
|
self.log_softmax = nn.LogSoftmax(axis=-1)
|
||||||
|
@ -172,7 +172,7 @@ class Seq2seqTraining(nn.Cell):
|
||||||
def __init__(self, config, is_training, use_one_hot_embeddings):
|
def __init__(self, config, is_training, use_one_hot_embeddings):
|
||||||
super(Seq2seqTraining, self).__init__()
|
super(Seq2seqTraining, self).__init__()
|
||||||
self.seq2seq = Seq2seqModel(config, is_training, use_one_hot_embeddings)
|
self.seq2seq = Seq2seqModel(config, is_training, use_one_hot_embeddings)
|
||||||
self.projection = PredLogProbs(config)
|
self.projection = PredLogProbs()
|
||||||
|
|
||||||
def construct(self, source_ids, source_mask, target_ids):
|
def construct(self, source_ids, source_mask, target_ids):
|
||||||
"""
|
"""
|
||||||
|
@ -285,7 +285,6 @@ class Seq2seqTrainOneStepWithLossScaleCell(nn.Cell):
|
||||||
self.get_status = P.NPUGetFloatStatus()
|
self.get_status = P.NPUGetFloatStatus()
|
||||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
|
||||||
self.base = Tensor(1, mstype.float32)
|
self.base = Tensor(1, mstype.float32)
|
||||||
self.less_equal = P.LessEqual()
|
self.less_equal = P.LessEqual()
|
||||||
self.hyper_map = C.HyperMap()
|
self.hyper_map = C.HyperMap()
|
||||||
|
|
|
@ -14,9 +14,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Train api."""
|
"""Train api."""
|
||||||
import os
|
import os
|
||||||
|
import ast
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import moxing as mox
|
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
@ -25,7 +25,7 @@ from mindspore.nn.optim import Lamb
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
|
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
|
||||||
from mindspore.train.callback import LossMonitor, SummaryCollector
|
from mindspore.train.callback import LossMonitor, SummaryCollector
|
||||||
from mindspore import context, Parameter
|
from mindspore import context, Parameter
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.communication import management as MultiAscend
|
from mindspore.communication import management as MultiAscend
|
||||||
|
@ -42,28 +42,28 @@ from src.utils.optimizer import Adam
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Seq2seq train entry point.')
|
parser = argparse.ArgumentParser(description='Seq2seq train entry point.')
|
||||||
|
|
||||||
is_modelarts = False
|
parser.add_argument("--is_modelarts", type=ast.literal_eval, default=False, help="model config json file path.")
|
||||||
|
parser.add_argument("--data_url", type=str, default=None, help="pre-train dataset address.")
|
||||||
if is_modelarts:
|
parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
|
||||||
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
|
|
||||||
parser.add_argument("--data_url", type=str, required=True, help="pre-train dataset address.")
|
|
||||||
parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
|
|
||||||
|
|
||||||
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
|
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
|
||||||
parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")
|
parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.is_modelarts:
|
||||||
|
import moxing as mox
|
||||||
context.set_context(
|
context.set_context(
|
||||||
mode=context.GRAPH_MODE,
|
mode=context.GRAPH_MODE,
|
||||||
save_graphs=True,
|
save_graphs=True,
|
||||||
device_target="Ascend",
|
device_target="Ascend",
|
||||||
reserve_class_name_in_scope=True)
|
reserve_class_name_in_scope=True)
|
||||||
|
|
||||||
|
|
||||||
def get_config(config):
|
def get_config(config):
|
||||||
config = Seq2seqConfig.from_json_file(config)
|
config = Seq2seqConfig.from_json_file(config)
|
||||||
config.compute_type = mstype.float16
|
config.compute_type = mstype.float16
|
||||||
config.dtype = mstype.float32
|
config.dtype = mstype.float32
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def _train(model, config: Seq2seqConfig,
|
def _train(model, config: Seq2seqConfig,
|
||||||
pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None,
|
pre_training_dataset=None, fine_tune_dataset=None, test_dataset=None,
|
||||||
callbacks: list = None):
|
callbacks: list = None):
|
||||||
|
@ -333,17 +333,16 @@ def _check_args(config):
|
||||||
if not isinstance(config, str):
|
if not isinstance(config, str):
|
||||||
raise ValueError("`config` must be type of str.")
|
raise ValueError("`config` must be type of str.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
_rank_size = os.getenv('RANK_SIZE')
|
_rank_size = os.getenv('RANK_SIZE')
|
||||||
args, _ = parser.parse_known_args()
|
|
||||||
|
|
||||||
if is_modelarts:
|
|
||||||
mox.file.copy_parallel(src_url=args.data_url, dst_url='/cache/dataset_menu/')
|
|
||||||
_config.pre_train_dataset = '/cache/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord'
|
|
||||||
_config.ckpt_path = '/cache/train_output/'
|
|
||||||
|
|
||||||
_check_args(args.config)
|
_check_args(args.config)
|
||||||
_config = get_config(args.config)
|
_config = get_config(args.config)
|
||||||
|
if args.is_modelarts:
|
||||||
|
mox.file.copy_parallel(src_url=args.data_url, dst_url='/cache/dataset_menu/')
|
||||||
|
_config.pre_train_dataset = '/cache/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord'
|
||||||
|
_config.ckpt_path = '/cache/train_output/'
|
||||||
_config.pre_train_dataset = args.pre_train_dataset
|
_config.pre_train_dataset = args.pre_train_dataset
|
||||||
|
|
||||||
set_seed(_config.random_seed)
|
set_seed(_config.random_seed)
|
||||||
|
@ -353,5 +352,5 @@ if __name__ == '__main__':
|
||||||
else:
|
else:
|
||||||
train_single(_config)
|
train_single(_config)
|
||||||
|
|
||||||
if is_modelarts:
|
if args.is_modelarts:
|
||||||
mox.file.copy_parallel(src_url='/cache/train_output/', dst_url=args.train_url)
|
mox.file.copy_parallel(src_url='/cache/train_output/', dst_url=args.train_url)
|
||||||
|
|
Loading…
Reference in New Issue