textrcnn
Commit bced1e6d4e (parent 50304d82fb)
default_config.yaml
@@ -0,0 +1,60 @@
+# Builtin configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# URL for ModelArts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "Ascend"
+enable_profiling: False
+
+# ==============================================================================
+pos_dir: 'data/rt-polaritydata/rt-polarity.pos'
+neg_dir: 'data/rt-polaritydata/rt-polarity.neg'
+num_epochs: 10
+lstm_num_epochs: 15
+batch_size: 64
+cell: 'gru'
+ckpt_folder_path: './ckpt'
+preprocess_path: './preprocess'
+preprocess: 'false'
+data_root: './data/'
+lr: 0.001  # 1e-3
+lstm_lr_init: 0.002  # 2e-3
+lstm_lr_end: 0.0005  # 5e-4
+lstm_lr_max: 0.003  # 3e-3
+lstm_lr_warm_up_epochs: 2
+lstm_lr_adjust_epochs: 9
+emb_path: './word2vec'
+embed_size: 300
+save_checkpoint_steps: 149
+keep_checkpoint_max: 10
+ckpt_path: ''
+
+# Export related
+ckpt_file: ''
+file_name: 'textrcnn'
+file_format: "MINDIR"
+
+---
+
+# Help description for each configuration
+# ModelArts related
+enable_modelarts: "Whether training on ModelArts, default: False"
+data_url: "URL for ModelArts"
+train_url: "URL for ModelArts"
+data_path: "The location of the input data."
+output_path: "The location of the output file."
+device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
+enable_profiling: "Whether to enable profiling while training, default: False"
+# Export related
+ckpt_file: "textrcnn ckpt file."
+file_name: "textrcnn output file name."
+file_format: "file format, choose from MINDIR or AIR"
+
+---
+file_format: ["AIR", "MINDIR"]
+device_target: ["Ascend"]
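The file above holds three YAML documents separated by `---`: the defaults, per-key help strings, and per-key legal choices. A minimal sketch (assuming PyYAML is installed; the keys shown are a small illustrative subset) of how such a multi-document file splits apart, mirroring what `src/model_utils/config.py` does later in this commit:

```python
import yaml

# illustrative stand-in for default_config.yaml (three '---'-separated documents)
raw = """
batch_size: 64
cell: 'gru'
---
batch_size: "training batch size"
cell: "the RNN architecture"
---
cell: ['vanilla', 'gru', 'lstm']
"""

defaults, helps, choices = list(yaml.load_all(raw, Loader=yaml.FullLoader))
print(defaults['batch_size'])  # 64
print(helps['cell'])           # "the RNN architecture"
print(choices['cell'])         # ['vanilla', 'gru', 'lstm']
```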
eval.py
@@ -14,7 +14,6 @@
 # ============================================================================
 """model evaluation script"""
 import os
-import argparse
 import numpy as np
 
 import mindspore.nn as nn
@@ -23,25 +22,32 @@ from mindspore import Tensor
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import LossMonitor
+# from mindspore.train.callback import LossMonitor
 from mindspore.common import set_seed
 
-from src.config import textrcnn_cfg as cfg
 from src.dataset import create_dataset
 from src.textrcnn import textrcnn
+from src.model_utils.moxing_adapter import moxing_wrapper
+from src.model_utils.config import config as cfg
+from src.model_utils.device_adapter import get_device_id
 
 set_seed(1)
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='textrcnn')
-    parser.add_argument('--ckpt_path', type=str)
-    args = parser.parse_args()
+def modelarts_pre_process():
+    '''modelarts pre process function.'''
+    cfg.ckpt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg.ckpt_path)
+    cfg.preprocess_path = cfg.data_path
+
+
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_eval():
+    '''eval function.'''
     context.set_context(
         mode=context.GRAPH_MODE,
         save_graphs=False,
         device_target="Ascend")
 
-    device_id = int(os.getenv('DEVICE_ID'))
+    device_id = get_device_id()
    context.set_context(device_id=device_id)
 
    embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
@@ -49,12 +55,15 @@ if __name__ == '__main__':
                       cell=cfg.cell, batch_size=cfg.batch_size)
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
     eval_net = nn.WithEvalCell(network, loss, True)
-    loss_cb = LossMonitor()
+    # loss_cb = LossMonitor()
     print("============== Starting Testing ==============")
     ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False)
-    param_dict = load_checkpoint(args.ckpt_path)
+    param_dict = load_checkpoint(cfg.ckpt_path)
     load_param_into_net(network, param_dict)
     network.set_train(False)
     model = Model(network, loss, metrics={'acc': Accuracy()}, eval_network=eval_net, eval_indexes=[0, 1, 2])
     acc = model.eval(ds_eval, dataset_sink_mode=False)
     print("============== Accuracy:{} ==============".format(acc))
+
+if __name__ == '__main__':
+    run_eval()
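For orientation, the eval wiring kept above follows the standard MindSpore pattern: `nn.WithEvalCell` emits `(loss, logits, label)`, and `eval_indexes=[0, 1, 2]` tells `Model` which of those outputs are the loss, the prediction, and the label. A minimal sketch (assumes MindSpore is installed; the tiny Dense net is a stand-in for textrcnn, not the real model):

```python
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy

net = nn.Dense(300, 2)                       # stand-in for the textrcnn network
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
eval_net = nn.WithEvalCell(net, loss, True)  # emits (loss, logits, label)
model = Model(net, loss, metrics={'acc': Accuracy()},
              eval_network=eval_net, eval_indexes=[0, 1, 2])
# model.eval(ds_eval, dataset_sink_mode=False) would then drive the Accuracy metric
```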
export.py
@@ -14,26 +14,16 @@
 # ============================================================================
 """textrcnn export ckpt file to mindir/air"""
 import os
-import argparse
 import numpy as np
 from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
 
 from src.textrcnn import textrcnn
-from src.config import textrcnn_cfg as config
+from src.model_utils.config import config
+from src.model_utils.device_adapter import get_device_id
 
-parser = argparse.ArgumentParser(description="textrcnn")
-parser.add_argument("--device_id", type=int, default=0, help="Device id")
-parser.add_argument("--ckpt_file", type=str, required=True, help="textrcnn ckpt file.")
-parser.add_argument("--file_name", type=str, default="textrcnn", help="textrcnn output file name.")
-parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"],
-                    default="MINDIR", help="file format")
-parser.add_argument("--device_target", type=str, choices=["Ascend"], default="Ascend",
-                    help="device target")
-args = parser.parse_args()
-
-context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
-
-if __name__ == "__main__":
+def run_export():
+    '''export function.'''
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
     # define net
     embedding_table = np.loadtxt(os.path.join(config.preprocess_path, "weight.txt")).astype(np.float32)
@@ -41,9 +31,12 @@ if __name__ == "__main__":
                    cell=config.cell, batch_size=config.batch_size)
 
     # load checkpoint
-    param_dict = load_checkpoint(args.ckpt_file)
+    param_dict = load_checkpoint(config.ckpt_file)
     load_param_into_net(net, param_dict)
     net.set_train(False)
 
     image = Tensor(np.ones([config.batch_size, 50], np.int32))
-    export(net, image, file_name=args.file_name, file_format=args.file_format)
+    export(net, image, file_name=config.file_name, file_format=config.file_format)
+
+if __name__ == "__main__":
+    run_export()
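`export` traces the graph with a dummy batch, so its shape and dtype fix the exported input signature; here that is a `[batch_size, 50]` int32 tensor of token ids (reading 50 as the fixed sequence length is an inference from this file, not stated elsewhere in the diff). A numpy-only sketch of the dummy input:

```python
import numpy as np

batch_size, seq_len = 64, 50  # seq_len 50 mirrors the literal in export.py
dummy_ids = np.ones([batch_size, seq_len], np.int32)
print(dummy_ids.shape, dummy_ids.dtype)  # (64, 50) int32 -- the traced signature
```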
postprocess.py
@@ -18,7 +18,7 @@ import argparse
 import numpy as np
 
 from mindspore.nn.metrics import Accuracy
-from src.config import textrcnn_cfg as cfg
+from src.model_utils.config import config as cfg
 
 parser = argparse.ArgumentParser(description='postprocess')
 parser.add_argument('--label_path', type=str, default="./preprocess_Result/label_ids.npy")
preprocess.py
@@ -17,7 +17,7 @@ import os
 import argparse
 import numpy as np
 
-from src.config import textrcnn_cfg as cfg
+from src.model_utils.config import config as cfg
 from src.dataset import create_dataset
 
 parser = argparse.ArgumentParser(description='preprocess')
README.md
@@ -65,6 +65,8 @@ Dataset used: [Sentence polarity dataset v1.0](http://www.cs.cornell.edu/people/
 
 - Running on Ascend
 
+  If you are running the scripts for the first time, you must set the parameter 'preprocess' to 'true' in `default_config.yaml` and run training once to generate the folder 'preprocess' containing the processed data.
+
 ```python
 # run training
 DEVICE_ID=7 python train.py
@@ -77,71 +79,121 @@ DEVICE_ID=7 python eval.py --ckpt_path {checkpoint path}
 bash scripts/run_eval.sh
 ```
 
+- Running on ModelArts
+
+  If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/); you can start training as follows.
+
+  You have to prepare the folder 'preprocess'.
+
+  You can rename 'requirements.txt' to 'pip-requirements.txt' so that the listed third-party libraries are installed automatically on ModelArts.
+
+  - Training standalone on ModelArts
+
+  ```python
+  # (1) Upload the code folder to S3 bucket.
+  # (2) Click "create training task" on the website UI interface.
+  # (3) Set the code directory to "/{path}/textrcnn" on the website UI interface.
+  # (4) Set the startup file to "/{path}/textrcnn/train.py" on the website UI interface.
+  # (5) Perform a or b.
+  #     a. Set parameters in /{path}/textrcnn/default_config.yaml.
+  #        1. Set "enable_modelarts: True"
+  #        2. Set "cell: 'lstm'" (default is 'gru'; do this step only if you want to use lstm)
+  #     b. Add the parameters on the website UI interface.
+  #        1. Set "enable_modelarts=True"
+  #        2. Set "cell=lstm" (default is 'gru'; do this step only if you want to use lstm)
+  # (6) Upload the dataset (the folder 'preprocess') to S3 bucket.
+  # (7) Check the "data storage location" on the website UI interface and set the "Dataset path".
+  # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
+  # (9) Under the item "resource pool selection", select the specification of a single card.
+  # (10) Create your job.
+  ```
+
+  - Evaluating with a single card on ModelArts
+
+  ```python
+  # (1) Upload the code folder to S3 bucket.
+  # (2) Click "create training task" on the website UI interface.
+  # (3) Set the code directory to "/{path}/textrcnn" on the website UI interface.
+  # (4) Set the startup file to "/{path}/textrcnn/eval.py" on the website UI interface.
+  # (5) Perform a or b.
+  #     a. Set parameters in /{path}/textrcnn/default_config.yaml.
+  #        1. Set "enable_modelarts: True"
+  #        2. Set "cell: 'lstm'" (default is 'gru'; do this step only if you want to use lstm)
+  #        3. Set "ckpt_path: './{path}/*.ckpt'" (the *.ckpt file must be under the folder 'textrcnn')
+  #     b. Add the parameters on the website UI interface.
+  #        1. Set "enable_modelarts=True"
+  #        2. Set "cell=lstm" (default is 'gru'; do this step only if you want to use lstm)
+  #        3. Set "ckpt_path=./{path}/*.ckpt" (the *.ckpt file must be under the folder 'textrcnn')
+  # (6) Upload the dataset (the folder 'preprocess') to S3 bucket.
+  # (7) Check the "data storage location" on the website UI interface and set the "Dataset path".
+  # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
+  # (9) Under the item "resource pool selection", select the specification of a single card.
+  # (10) Create your job.
+  ```
+
 ## [Script Description](#contents)
 
 ### [Script and Sample Code](#contents)
 
 ```python
 ├── model_zoo
     ├── README.md    // descriptions about all the models
     ├── textrcnn
+        ├── README.md    // descriptions about TextRCNN
         ├── data_src
         │   ├──rt-polaritydata    // directory to save the source data
         │   ├──rt-polaritydata.README.1.0.txt    // readme file of dataset
-        ├── readme.md    // descriptions about TextRCNN
+        ├── ascend310_infer    // application for 310 inference
         ├── scripts
         │   ├──run_train.sh    // shell script for train on Ascend
-        │   ├──run_eval.sh    // shell script for evaluation on Ascend
-        │   ├──sample.txt    // example shell to run the above two scripts
+        │   ├──run_infer_310.sh    // shell script for 310 infer
+        │   └──run_eval.sh    // shell script for evaluation on Ascend
        ├── src
-        │   ├──dataset.py    // creating dataset
-        │   ├──textrcnn.py    // textrcnn architecture
-        │   ├──config.py    // parameter configuration
-        ├── train.py    // training script
-        ├── export.py    // export script
-        ├── eval.py    // evaluation script
-        ├── data_helpers.py    // dataset split script
-        ├── sample.txt    // the shell to train and eval the model without scripts
+        │   ├──model_utils
+        │   │   ├──config.py    // parsing parameter configuration file of "*.yaml"
+        │   │   ├──device_adapter.py    // local or ModelArts training
+        │   │   ├──local_adapter.py    // get related environment variables in local training
+        │   │   └──moxing_adapter.py    // get related environment variables in ModelArts training
+        │   ├──dataset.py    // creating dataset
+        │   ├──textrcnn.py    // textrcnn architecture
+        │   └──utils.py    // function related to learning rate
+        ├── data_helpers.py    // dataset split script
+        ├── default_config.yaml    // parameter configuration
+        ├── eval.py    // evaluation script
+        ├── export.py    // export script
+        ├── mindspore_hub_conf.py    // mindspore hub interface
+        ├── postprocess.py    // 310 infer postprocess script
+        ├── preprocess.py    // dataset generation script
+        ├── requirements.txt    // third-party libraries that need to be installed
+        ├── sample.txt    // the shell to train and eval the model without '*.sh'
+        └── train.py    // training script
 ```
 
 ### [Script Parameters](#contents)
 
-Parameters for both training and evaluation can be set in config.py
+Parameters for both training and evaluation can be set in `default_config.yaml`.
 
 - config for Textrcnn, Sentence polarity dataset v1.0.
 
 ```python
-'num_epochs': 10,                  # total training epochs
-'lstm_num_epochs': 15,             # total training epochs when using lstm
-'batch_size': 64,                  # training batch size
-'cell': 'gru',                     # the RNN architecture, can be 'vanilla', 'gru' and 'lstm'
-'ckpt_folder_path': './ckpt',      # the path to save the checkpoints
-'preprocess_path': './preprocess', # the directory to save the processed data
-'preprocess': 'false',             # whether to preprocess the data
-'data_path': './data/',            # the path to store the split data
-'lr': 1e-3,                        # the training learning rate
-'lstm_lr_init': 2e-3,              # learning rate initial value when using lstm
-'lstm_lr_end': 5e-4,               # learning rate end value when using lstm
-'lstm_lr_max': 3e-3,               # learning rate max value when using lstm
-'lstm_lr_warm_up_epochs': 2,       # warm up epoch num when using lstm
-'lstm_lr_adjust_epochs': 9,        # lr adjusted within lr_adjust_epochs; after that, the lr is lr_end when using lstm
-'emb_path': './word2vec',          # the directory to save the embedding file
-'embed_size': 300,                 # the dimension of the word embedding
-'save_checkpoint_steps': 149,      # the step interval to save the checkpoint
-'keep_checkpoint_max': 10,         # max checkpoints to save
+num_epochs: 10                  # total training epochs
+lstm_num_epochs: 15             # total training epochs when using lstm
+batch_size: 64                  # training batch size
+cell: 'gru'                     # the RNN architecture, can be 'vanilla', 'gru' and 'lstm'
+ckpt_folder_path: './ckpt'      # the path to save the checkpoints
+preprocess_path: './preprocess' # the directory to save the processed data
+preprocess: 'false'             # whether to preprocess the data
+data_root: './data/'            # the path to store the split data
+lr: 0.001                       # the training learning rate (1e-3)
+lstm_lr_init: 0.002             # learning rate initial value when using lstm (2e-3)
+lstm_lr_end: 0.0005             # learning rate end value when using lstm (5e-4)
+lstm_lr_max: 0.003              # learning rate max value when using lstm (3e-3)
+lstm_lr_warm_up_epochs: 2       # warm up epoch num when using lstm
+lstm_lr_adjust_epochs: 9        # lr adjusted within lr_adjust_epochs; after that, the lr is lr_end when using lstm
+emb_path: './word2vec'          # the directory to save the embedding file
+embed_size: 300                 # the dimension of the word embedding
+save_checkpoint_steps: 149      # the step interval to save the checkpoint
+keep_checkpoint_max: 10         # max checkpoints to save
+ckpt_path: ''                   # path of the '*.ckpt' to be evaluated, relative to eval.py
 ```
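Because `src/model_utils/config.py` registers every scalar key of `default_config.yaml` as a command line flag, any parameter above can also be overridden at launch time without editing the file, for example (flag values illustrative, following the README's own command style):

```python
# override the cell type and learning rate at launch (flags are generated
# from the yaml keys by src/model_utils/config.py)
DEVICE_ID=7 python train.py --cell lstm --lr 0.002
# evaluate a specific checkpoint
DEVICE_ID=7 python eval.py --ckpt_path ./ckpt/gru-10_149.ckpt
```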
 
 ### Performance
 
 | Model      | MindSpore + Ascend             | TensorFlow + GPU               |
 | ---------- | ------------------------------ | ------------------------------ |
 | Resource   | Ascend 910; OS Euler2.8        | NV SMX2 V100-32G               |
 | Version    | 1.0.1                          | 1.4.0                          |
 | Dataset    | Sentence polarity dataset v1.0 | Sentence polarity dataset v1.0 |
 | batch_size | 64                             | 64                             |
 | Accuracy   | 0.78                           | 0.78                           |
 | Speed      | 35 ms/step                     | 77 ms/step                     |
 
 ## Inference Process
 
 ### [Export MindIR](#contents)
@@ -173,6 +225,17 @@ Inference result is saved in current path, you can find result like this in acc.
 ============== Accuracy:{} ============== 0.8008
 ```
 
+### Performance
+
+| Model      | MindSpore + Ascend             | TensorFlow + GPU               |
+| ---------- | ------------------------------ | ------------------------------ |
+| Resource   | Ascend 910; OS Euler2.8        | NV SMX2 V100-32G               |
+| Version    | 1.0.1                          | 1.4.0                          |
+| Dataset    | Sentence polarity dataset v1.0 | Sentence polarity dataset v1.0 |
+| batch_size | 64                             | 64                             |
+| Accuracy   | 0.78                           | 0.78                           |
+| Speed      | 35 ms/step                     | 77 ms/step                     |
+
 ## [ModelZoo Homepage](#contents)
 
 Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
scripts/run_eval.sh
@@ -13,6 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+if [ $# != 2 ] && [ $# != 1 ]; then
+    echo "Usage: sh run_eval.sh [CKPT_PATH] [DEVICE_ID]"
+    exit 1
+fi
+
+if [ $# == 2 ]; then
+    export DEVICE_ID=$2
+else
+    export DEVICE_ID=0
+fi
+
 ulimit -u unlimited
 
 BASEPATH=$(cd "`dirname $0`" || exit; pwd)
scripts/run_train.sh
@@ -13,6 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+if [ $# != 1 ] && [ $# != 0 ]; then
+    echo "Usage: sh run_train.sh [DEVICE_ID]"
+    exit 1
+fi
+
+if [ $# == 1 ]; then
+    export DEVICE_ID=$1
+else
+    export DEVICE_ID=0
+fi
+
 ulimit -u unlimited
 
@ -1,42 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
# LSTM CONFIG
|
||||
textrcnn_cfg = edict({
|
||||
'pos_dir': 'data/rt-polaritydata/rt-polarity.pos',
|
||||
'neg_dir': 'data/rt-polaritydata/rt-polarity.neg',
|
||||
'num_epochs': 10,
|
||||
'lstm_num_epochs': 15,
|
||||
'batch_size': 64,
|
||||
'cell': 'gru',
|
||||
'ckpt_folder_path': './ckpt',
|
||||
'preprocess_path': './preprocess',
|
||||
'preprocess': 'false',
|
||||
'data_path': './data/',
|
||||
'lr': 1e-3,
|
||||
'lstm_lr_init': 2e-3,
|
||||
'lstm_lr_end': 5e-4,
|
||||
'lstm_lr_max': 3e-3,
|
||||
'lstm_lr_warm_up_epochs': 2,
|
||||
'lstm_lr_adjust_epochs': 9,
|
||||
'emb_path': './word2vec',
|
||||
'embed_size': 300,
|
||||
'save_checkpoint_steps': 149,
|
||||
'keep_checkpoint_max': 10,
|
||||
})
|
|
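The edict above and the yaml-backed `Config` that replaces it expose the same attribute-style access, which is why call sites in this commit only needed their import swapped. A small sketch (assuming the `easydict` package is installed):

```python
from easydict import EasyDict as edict

old_cfg = edict({'batch_size': 64, 'cell': 'gru'})  # the removed style
print(old_cfg.batch_size, old_cfg.cell)             # 64 gru

# the replacement below (src/model_utils/config.py) builds the same attribute
# view from default_config.yaml and adds generated command line flags
```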
src/model_utils/config.py
@@ -0,0 +1,130 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Parse arguments"""
+
+import os
+import ast
+import argparse
+from pprint import pprint, pformat
+import yaml
+
+
+class Config:
+    """
+    Configuration namespace. Convert dictionary to members.
+    """
+    def __init__(self, cfg_dict):
+        for k, v in cfg_dict.items():
+            if isinstance(v, (list, tuple)):
+                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
+            else:
+                setattr(self, k, Config(v) if isinstance(v, dict) else v)
+
+    def __str__(self):
+        return pformat(self.__dict__)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
+    """
+    Parse command line arguments to the configuration according to the default yaml.
+
+    Args:
+        parser: Parent parser.
+        cfg: Base configuration.
+        helper: Helper description.
+        cfg_path: Path to the default yaml config.
+    """
+    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
+                                     parents=[parser])
+    helper = {} if helper is None else helper
+    choices = {} if choices is None else choices
+    for item in cfg:
+        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
+            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
+            choice = choices[item] if item in choices else None
+            if isinstance(cfg[item], bool):
+                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
+                                    help=help_description)
+            else:
+                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
+                                    help=help_description)
+    args = parser.parse_args()
+    return args
+
+
+def parse_yaml(yaml_path):
+    """
+    Parse the yaml config file.
+
+    Args:
+        yaml_path: Path to the yaml config.
+    """
+    with open(yaml_path, 'r') as fin:
+        try:
+            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
+            cfgs = [x for x in cfgs]
+            if len(cfgs) == 1:
+                cfg_helper = {}
+                cfg = cfgs[0]
+                cfg_choices = {}
+            elif len(cfgs) == 2:
+                cfg, cfg_helper = cfgs
+                cfg_choices = {}
+            elif len(cfgs) == 3:
+                cfg, cfg_helper, cfg_choices = cfgs
+            else:
+                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
+            print(cfg_helper)
+        except:
+            raise ValueError("Failed to parse yaml")
+    return cfg, cfg_helper, cfg_choices
+
+
+def merge(args, cfg):
+    """
+    Merge the base config from yaml file and command line arguments.
+
+    Args:
+        args: Command line arguments.
+        cfg: Base configuration.
+    """
+    args_var = vars(args)
+    for item in args_var:
+        cfg[item] = args_var[item]
+    return cfg
+
+
+def get_config():
+    """
+    Get Config according to the yaml file and cli arguments.
+    """
+    parser = argparse.ArgumentParser(description="default name", add_help=False)
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
+                        help="Config file path")
+    path_args, _ = parser.parse_known_args()
+    default, helper, choices = parse_yaml(path_args.config_path)
+    pprint(default)
+    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
+    final_config = merge(args, default)
+    return Config(final_config)
+
+config = get_config()
+
+if __name__ == '__main__':
+    print(config)
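A self-contained re-sketch of the `Config` namespace idea above (nested dicts become nested attribute namespaces; the list handling of the original is omitted for brevity):

```python
from pprint import pformat

class Config:
    """Minimal re-sketch of the Config namespace: dict -> attributes."""
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __repr__(self):
        return pformat(self.__dict__)

cfg = Config({'cell': 'gru', 'train': {'batch_size': 64}})
print(cfg.cell)              # 'gru'
print(cfg.train.batch_size)  # 64 -- nested dicts become nested namespaces
```

Note that `config = get_config()` runs at import time, so importing this module parses `sys.argv`; scripts that use it should not define clashing argparse flags of their own.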
src/model_utils/device_adapter.py
@@ -0,0 +1,27 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Device adapter for ModelArts"""
+
+from src.model_utils.config import config
+
+if config.enable_modelarts:
+    from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
+else:
+    from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
+
+__all__ = [
+    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
+]
src/model_utils/local_adapter.py
@@ -0,0 +1,36 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Local adapter"""
+
+import os
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    return "Local Job"
src/model_utils/moxing_adapter.py
@@ -0,0 +1,123 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Moxing adapter for ModelArts"""
+
+import os
+import functools
+from mindspore import context
+from mindspore.profiler import Profiler
+from src.model_utils.config import config
+
+_global_sync_count = 0
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    job_id = os.getenv('JOB_ID')
+    job_id = job_id if job_id != "" else "default"
+    return job_id
+
+def sync_data(from_path, to_path):
+    """
+    Download data from remote OBS to a local directory if the first path is a remote URL and the second is local;
+    otherwise upload data from the local directory to remote OBS.
+    """
+    import moxing as mox
+    import time
+    global _global_sync_count
+    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
+    _global_sync_count += 1
+
+    # Each server contains at most 8 devices.
+    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+        print("from path: ", from_path)
+        print("to path: ", to_path)
+        mox.file.copy_parallel(from_path, to_path)
+        print("===finish data synchronization===")
+        try:
+            os.mknod(sync_lock)
+            # print("os.mknod({}) success".format(sync_lock))
+        except IOError:
+            pass
+        print("===save flag===")
+
+    while True:
+        if os.path.exists(sync_lock):
+            break
+        time.sleep(1)
+
+    print("Finish sync data from {} to {}.".format(from_path, to_path))
+
+
+def moxing_wrapper(pre_process=None, post_process=None):
+    """
+    Moxing wrapper to download dataset and upload outputs.
+    """
+    def wrapper(run_func):
+        @functools.wraps(run_func)
+        def wrapped_func(*args, **kwargs):
+            # Download data from data_url
+            if config.enable_modelarts:
+                if config.data_url:
+                    sync_data(config.data_url, config.data_path)
+                    print("Dataset downloaded: ", os.listdir(config.data_path))
+                if config.checkpoint_url:
+                    sync_data(config.checkpoint_url, config.load_path)
+                    print("Preload downloaded: ", os.listdir(config.load_path))
+                if config.train_url:
+                    sync_data(config.train_url, config.output_path)
+                    print("Workspace downloaded: ", os.listdir(config.output_path))
+
+                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
+                config.device_num = get_device_num()
+                config.device_id = get_device_id()
+                if not os.path.exists(config.output_path):
+                    os.makedirs(config.output_path)
+
+                if pre_process:
+                    pre_process()
+
+            if config.enable_profiling:
+                profiler = Profiler()
+
+            run_func(*args, **kwargs)
+
+            if config.enable_profiling:
+                profiler.analyse()
+
+            # Upload data to train_url
+            if config.enable_modelarts:
+                if post_process:
+                    post_process()
+
+                if config.train_url:
+                    print("Start to copy output directory")
+                    sync_data(config.output_path, config.train_url)
+        return wrapped_func
+    return wrapper
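`moxing_wrapper` is a decorator-with-hooks: sync data down, run the entry point (optionally profiled), sync outputs back up. The same control flow in a moxing-free, runnable miniature (hook bodies illustrative):

```python
import functools

def with_hooks(pre=None, post=None):
    """Run optional pre/post hooks around the wrapped entry point."""
    def wrapper(fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            if pre:
                pre()          # e.g. download dataset from OBS
            fn(*args, **kwargs)
            if post:
                post()         # e.g. upload outputs back to OBS
        return wrapped
    return wrapper

@with_hooks(pre=lambda: print("sync down"), post=lambda: print("sync up"))
def run_train():
    print("training")

run_train()  # prints: sync down / training / sync up
```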
src/textrcnn.py
@@ -63,6 +63,10 @@ class textrcnn(nn.Cell):
             self.rnnU_fw = nn.Dense(self.embed_size, self.num_hiddens)
             self.rnnW_bw = nn.Dense(self.num_hiddens, self.num_hiddens)
             self.rnnU_bw = nn.Dense(self.embed_size, self.num_hiddens)
+            self.rnnW_fw.to_float(mstype.float16)
+            self.rnnU_fw.to_float(mstype.float16)
+            self.rnnW_bw.to_float(mstype.float16)
+            self.rnnU_bw.to_float(mstype.float16)
 
         if cell == "gru":
             self.rnnWr_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
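`Cell.to_float(mstype.float16)` makes these Dense layers compute in fp16, the usual Ascend-friendly mixed-precision idiom in MindSpore, while their parameters stay fp32. A minimal sketch (assumes MindSpore is installed; sizes illustrative):

```python
import mindspore.nn as nn
import mindspore.common.dtype as mstype

dense = nn.Dense(300, 512)
dense.to_float(mstype.float16)  # forward compute on this cell casts to fp16
```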
train.py
@@ -25,22 +25,31 @@ from mindspore.nn.metrics import Accuracy
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.common import set_seed
 
-from src.config import textrcnn_cfg as cfg
 from src.dataset import create_dataset
 from src.dataset import convert_to_mindrecord
 from src.textrcnn import textrcnn
 from src.utils import get_lr
+from src.model_utils.moxing_adapter import moxing_wrapper
+from src.model_utils.config import config as cfg
+from src.model_utils.device_adapter import get_device_id
 
 set_seed(0)
 
-if __name__ == '__main__':
+def modelarts_pre_process():
+    '''modelarts pre process function.'''
+    cfg.ckpt_folder_path = os.path.join(cfg.output_path, cfg.ckpt_folder_path)
+    cfg.preprocess_path = cfg.data_path
+
+
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_train():
+    '''train function.'''
     context.set_context(
         mode=context.GRAPH_MODE,
         save_graphs=False,
         device_target="Ascend")
 
-    device_id = int(os.getenv('DEVICE_ID'))
+    device_id = get_device_id()
     context.set_context(device_id=device_id)
 
     if cfg.preprocess == 'true':
@@ -48,7 +57,7 @@ if __name__ == '__main__':
         if os.path.exists(cfg.preprocess_path):
             shutil.rmtree(cfg.preprocess_path)
         os.mkdir(cfg.preprocess_path)
-        convert_to_mindrecord(cfg.embed_size, cfg.data_path, cfg.preprocess_path, cfg.emb_path)
+        convert_to_mindrecord(cfg.embed_size, cfg.data_root, cfg.preprocess_path, cfg.emb_path)
 
     if cfg.cell == "vanilla":
         print("============ Precision is lower than expected when using vanilla RNN architecture ===========")
@@ -79,3 +88,6 @@ if __name__ == '__main__':
     ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck)
     model.train(num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb, time_cb])
     print("train success")
+
+if __name__ == '__main__':
+    run_train()
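For reference, the checkpoint callback used by `run_train` follows the standard MindSpore pattern; the values below mirror `default_config.yaml` (a sketch, not the full training script):

```python
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

config_ck = CheckpointConfig(save_checkpoint_steps=149,  # save_checkpoint_steps
                             keep_checkpoint_max=10)     # keep_checkpoint_max
ckpoint_cb = ModelCheckpoint(prefix='gru',               # cfg.cell
                             directory='./ckpt',         # cfg.ckpt_folder_path
                             config=config_ck)
# passed to model.train(...) together with the LossMonitor and TimeMonitor callbacks
```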