textcnn gpu
This commit is contained in:
parent
562786bc44
commit
f152eb767f
|
@ -57,16 +57,42 @@ After installing MindSpore via the official website, you can start training and
|
|||
|
||||
```python
|
||||
# run training example
|
||||
# need set config_path in config.py file and set data_path in yaml file
|
||||
python train.py > train.log 2>&1 &
|
||||
# need set config_path in config.py file and set data_path in yaml file
|
||||
python train.py --config_path [CONFIG_PATH] \
|
||||
--device_target [TARGET] \
|
||||
--data_path [DATA_PATH]> train.log 2>&1 &
|
||||
OR
|
||||
sh scripts/run_train.sh dataset
|
||||
sh scripts/run_train.sh [DATASET]
|
||||
|
||||
# run evaluation example
|
||||
# need set config_path in config.py file and set data_path, checkpoint_file_path in yaml file
|
||||
python eval.py > eval.log 2>&1 &
|
||||
python eval.py --config_path [CONFIG_PATH] \
|
||||
--device_target [TARGET] \
|
||||
--checkpoint_file_path [CKPT_FILE] \
|
||||
--data_path [DATA_PATH] > eval.log 2>&1 &
|
||||
OR
|
||||
sh scripts/run_eval.sh checkpoint_file_path dataset
|
||||
sh scripts/run_eval.sh [CKPT_FILE] [DATASET]
|
||||
```
|
||||
|
||||
- running on GPU
|
||||
|
||||
```python
|
||||
# run training example
|
||||
# need set config_path in config.py file and set data_path in yaml file
|
||||
python train.py --config_path [CONFIG_PATH] \
|
||||
--device_target GPU \
|
||||
--data_path [DATA_PATH]> train.log 2>&1 &
|
||||
OR
|
||||
sh scripts/run_train_gpu.sh [DATASET] [DATA_PATH]
|
||||
|
||||
# run evaluation example
|
||||
# need set config_path in config.py file and set data_path, checkpoint_file_path in yaml file
|
||||
python eval.py --config_path [CONFIG_PATH] \
|
||||
--device_target GPU \
|
||||
--checkpoint_file_path [CKPT_FILE] \
|
||||
--data_path [DATA_PATH] > eval.log 2>&1 &
|
||||
OR
|
||||
sh scripts/run_eval.sh [CKPT_FILE] [DATASET] [DATA_PATH]
|
||||
```
|
||||
|
||||
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training and evaluation as follows:
|
||||
|
@ -114,6 +140,8 @@ If you want to run in modelarts, please check the official documentation of [mod
|
|||
│ ├── run_eval.sh // shell script for evaluation on Ascend
|
||||
│ ├── run_train_cpu.sh // shell script for training on CPU
|
||||
│ ├── run_eval_cpu.sh // shell script for evaluation on CPU
|
||||
│ ├── run_train_gpu.sh // shell script for training on GPU
|
||||
│ ├── run_eval_gpu.sh // shell script for evaluation on GPU
|
||||
├── src
|
||||
│ ├── dataset.py // Processing dataset
|
||||
│ ├── textcnn.py // textcnn architecture
|
||||
|
@ -160,10 +188,25 @@ For more configuration details, please refer the script `*.yaml`.
|
|||
- running on Ascend/CPU
|
||||
|
||||
```python
|
||||
# need set config_path in config.py file and set data_path in yaml file
|
||||
python train.py > train.log 2>&1 &
|
||||
# `CONFIG_PATH` `DATA_PATH` `DATASET` `DEVICE_TARGET` parameters need to be passed externally or modified yaml file
|
||||
# `DATASET` must choose from ['MR', 'SUBJ', 'SST2']"
|
||||
python train.py --config_path [CONFIG_PATH] \
|
||||
--device_target [DEVICE_TARGET] \
|
||||
--data_path [DATA_PATH]> train.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_train.sh dataset
|
||||
bash scripts/run_train.sh [DATASET]
|
||||
```
|
||||
|
||||
- running on GPU
|
||||
|
||||
```python
|
||||
# `CONFIG_PATH` `DATA_PATH` `DATASET` parameters need to be passed externally or modified yaml file
|
||||
# `DATASET` must choose from ['MR', 'SUBJ', 'SST2']"
|
||||
python train.py --config_path [CONFIG_PATH] \
|
||||
--device_target GPU \
|
||||
--data_path [DATA_PATH]> train.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_train.sh [DATASET] [DATA_PATH]
|
||||
```
|
||||
|
||||
The python command above will run in the background, you can view the results through the file `train.log`.
|
||||
|
@ -186,10 +229,27 @@ For more configuration details, please refer the script `*.yaml`.
|
|||
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., "username/textcnn/ckpt/train_textcnn.ckpt".
|
||||
|
||||
```python
|
||||
# need set config_path and set data_path in yaml file, checkpoint_file_path in yaml file
|
||||
python eval.py > eval.log 2>&1 &
|
||||
# `CONFIG_PATH` `DEVICE_TARGET` `CKPT_FILE` `DATA_PATH` `DATASET` parameters need to be passed externally or modified yaml file
|
||||
# `DATASET` must choose from ['MR', 'SUBJ', 'SST2']"
|
||||
python eval.py --config_path [CONFIG_PATH] \
|
||||
--device_target [DEVICE_TARGET] \
|
||||
--checkpoint_file_path [CKPT_FILE] \
|
||||
--data_path [DATA_PATH] > eval.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_eval.sh checkpoint_file_path dataset
|
||||
bash scripts/run_eval.sh [CKPT_FILE] [DATASET]
|
||||
```
|
||||
|
||||
- evaluation on movie review dataset when running on GPU
|
||||
|
||||
```python
|
||||
# `CONFIG_PATH` `CKPT_FILE` `DATA_PATH` `DATASET` parameters need to be passed externally or modified yaml file
|
||||
# `DATASET` must choose from ['MR', 'SUBJ', 'SST2']"
|
||||
python eval.py --config_path [CONFIG_PATH] \
|
||||
--device_target GPU \
|
||||
--checkpoint_file_path [CKPT_FILE] \
|
||||
--data_path [DATA_PATH] > eval.log 2>&1 &
|
||||
OR
|
||||
bash scripts/run_eval_gpu.sh [CKPT_FILE] [DATASET] [DATA_PATH]
|
||||
```
|
||||
|
||||
The above python command will run in the background. You can view the results through the file "eval.log". The accuracy of the test dataset will be as follows:
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $# -ne 3 ]; then
|
||||
echo "Usage: bash run_train_gpu.sh [CKPT_FILE] [DATASET] [DATA_PATH]
|
||||
DATASET must choose from ['MR', 'SUBJ', 'SST2']"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
data_path="$(get_real_path $3)"
|
||||
|
||||
if [ $# == 3 ]
|
||||
then
|
||||
if [ $2 == "MR" ]; then
|
||||
CONFIG_FILE="${BASE_PATH}/../mr_config.yaml"
|
||||
elif [ $2 == "SUBJ" ]; then
|
||||
CONFIG_FILE="${BASE_PATH}/../subj_config.yaml"
|
||||
elif [ $2 == "SST2" ]; then
|
||||
CONFIG_FILE="${BASE_PATH}/../sst2_config.yaml"
|
||||
else
|
||||
echo "error: the selected dataset is not in supported set{MR, SUBJ, SST2}"
|
||||
exit 1
|
||||
fi
|
||||
dataset_type=$2
|
||||
fi
|
||||
python ${BASE_PATH}/../eval.py \
|
||||
--device_target="GPU" \
|
||||
--checkpoint_file_path=$1 \
|
||||
--dataset=$dataset_type \
|
||||
--data_path=$data_path \
|
||||
--config_path=$CONFIG_FILE > eval.log 2>&1 &
|
|
@ -0,0 +1,53 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
if [ $# -ne 2 ]; then
|
||||
echo "Usage: bash run_train_gpu.sh [DATASET] [DATA_PATH]
|
||||
DATASET must choose from ['MR', 'SUBJ', 'SST2']"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
data_path="$(get_real_path $2)"
|
||||
|
||||
if [ $# == 2 ]
|
||||
then
|
||||
if [ $1 == "MR" ]; then
|
||||
CONFIG_FILE="${BASE_PATH}/../mr_config.yaml"
|
||||
elif [ $1 == "SUBJ" ]; then
|
||||
CONFIG_FILE="${BASE_PATH}/../subj_config.yaml"
|
||||
elif [ $1 == "SST2" ]; then
|
||||
CONFIG_FILE="${BASE_PATH}/../sst2_config.yaml"
|
||||
else
|
||||
echo "error: the selected dataset is not in supported set{MR, SUBJ, SST2}"
|
||||
exit 1
|
||||
fi
|
||||
dataset_type=$1
|
||||
fi
|
||||
|
||||
python ${BASE_PATH}/../train.py \
|
||||
--device_target="GPU" \
|
||||
--dataset=$dataset_type \
|
||||
--data_path=$data_path \
|
||||
--config_path=$CONFIG_FILE \
|
||||
--output_path='./output' > train.log 2>&1 &
|
|
@ -33,9 +33,11 @@ from src.textcnn import TextCNN
|
|||
from src.textcnn import SoftmaxCrossEntropyExpand
|
||||
from src.dataset import MovieReview, SST2, Subjectivity
|
||||
|
||||
|
||||
def modelarts_pre_process():
|
||||
config.checkpoint_path = os.path.join(config.output_path, str(get_rank_id()), config.checkpoint_path)
|
||||
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train_net():
|
||||
'''train net'''
|
||||
|
@ -69,12 +71,12 @@ def train_net():
|
|||
load_param_into_net(net, param_dict)
|
||||
|
||||
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), \
|
||||
learning_rate=learning_rate, weight_decay=float(config.weight_decay))
|
||||
learning_rate=learning_rate, weight_decay=float(config.weight_decay))
|
||||
loss = SoftmaxCrossEntropyExpand(sparse=True)
|
||||
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})
|
||||
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=int(config.epoch_size*batch_num/2),
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=int(config.epoch_size * batch_num / 2),
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
time_cb = TimeMonitor(data_size=batch_num)
|
||||
ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
|
||||
|
@ -83,5 +85,6 @@ def train_net():
|
|||
model.train(config.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
|
||||
print("train success")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_net()
|
||||
|
|
Loading…
Reference in New Issue