This commit is contained in:
郑彬 2021-06-10 15:23:41 +08:00
parent 50304d82fb
commit bced1e6d4e
15 changed files with 556 additions and 120 deletions

View File

@ -0,0 +1,60 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
enable_profiling: False
# ==============================================================================
pos_dir: 'data/rt-polaritydata/rt-polarity.pos'
neg_dir: 'data/rt-polaritydata/rt-polarity.neg'
num_epochs: 10
lstm_num_epochs: 15
batch_size: 64
cell: 'gru'
ckpt_folder_path: './ckpt'
preprocess_path: './preprocess'
preprocess: 'false'
data_root: './data/'
lr: 0.001 # 1e-3
lstm_lr_init: 0.002 # 2e-3
lstm_lr_end: 0.0005 # 5e-4
lstm_lr_max: 0.003 # 3e-3
lstm_lr_warm_up_epochs: 2
lstm_lr_adjust_epochs: 9
emb_path: './word2vec'
embed_size: 300
save_checkpoint_steps: 149
keep_checkpoint_max: 10
ckpt_path: ''
# Export related
ckpt_file: ''
file_name: 'textrcnn'
file_format: "MINDIR"
---
# Help description for each configuration
# ModelArts related
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: "Running platform, choose from Ascend, GPU or CPU, and default is Ascend."
enable_profiling: 'Whether enable profiling while training, default: False'
# Export related
ckpt_file: "textrcnn ckpt file."
file_name: "textrcnn output file name."
file_format: "file format, choose from MINDIR or AIR"
---
file_format: ["AIR", "MINDIR"]
device_target: ["Ascend"]

View File

@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
"""model evaluation script""" """model evaluation script"""
import os import os
import argparse
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
@ -23,25 +22,32 @@ from mindspore import Tensor
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import LossMonitor # from mindspore.train.callback import LossMonitor
from mindspore.common import set_seed from mindspore.common import set_seed
from src.config import textrcnn_cfg as cfg
from src.dataset import create_dataset from src.dataset import create_dataset
from src.textrcnn import textrcnn from src.textrcnn import textrcnn
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config as cfg
from src.model_utils.device_adapter import get_device_id
set_seed(1) set_seed(1)
if __name__ == '__main__': def modelarts_pre_process():
parser = argparse.ArgumentParser(description='textrcnn') '''modelarts pre process function.'''
parser.add_argument('--ckpt_path', type=str) cfg.ckpt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg.ckpt_path)
args = parser.parse_args() cfg.preprocess_path = cfg.data_path
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
'''eval function.'''
context.set_context( context.set_context(
mode=context.GRAPH_MODE, mode=context.GRAPH_MODE,
save_graphs=False, save_graphs=False,
device_target="Ascend") device_target="Ascend")
device_id = int(os.getenv('DEVICE_ID')) device_id = get_device_id()
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32) embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32)
@ -49,12 +55,15 @@ if __name__ == '__main__':
cell=cfg.cell, batch_size=cfg.batch_size) cell=cfg.cell, batch_size=cfg.batch_size)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
eval_net = nn.WithEvalCell(network, loss, True) eval_net = nn.WithEvalCell(network, loss, True)
loss_cb = LossMonitor() # loss_cb = LossMonitor()
print("============== Starting Testing ==============") print("============== Starting Testing ==============")
ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False) ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False)
param_dict = load_checkpoint(args.ckpt_path) param_dict = load_checkpoint(cfg.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
network.set_train(False) network.set_train(False)
model = Model(network, loss, metrics={'acc': Accuracy()}, eval_network=eval_net, eval_indexes=[0, 1, 2]) model = Model(network, loss, metrics={'acc': Accuracy()}, eval_network=eval_net, eval_indexes=[0, 1, 2])
acc = model.eval(ds_eval, dataset_sink_mode=False) acc = model.eval(ds_eval, dataset_sink_mode=False)
print("============== Accuracy:{} ==============".format(acc)) print("============== Accuracy:{} ==============".format(acc))
if __name__ == '__main__':
run_eval()

View File

@ -14,26 +14,16 @@
# ============================================================================ # ============================================================================
"""textrcnn export ckpt file to mindir/air""" """textrcnn export ckpt file to mindir/air"""
import os import os
import argparse
import numpy as np import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.textrcnn import textrcnn from src.textrcnn import textrcnn
from src.config import textrcnn_cfg as config from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
parser = argparse.ArgumentParser(description="textrcnn") def run_export():
parser.add_argument("--device_id", type=int, default=0, help="Device id") '''export function.'''
parser.add_argument("--ckpt_file", type=str, required=True, help="textrcnn ckpt file.") context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
parser.add_argument("--file_name", type=str, default="textrcnn", help="textrcnn output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"],
default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
if __name__ == "__main__":
# define net # define net
embedding_table = np.loadtxt(os.path.join(config.preprocess_path, "weight.txt")).astype(np.float32) embedding_table = np.loadtxt(os.path.join(config.preprocess_path, "weight.txt")).astype(np.float32)
@ -41,9 +31,12 @@ if __name__ == "__main__":
cell=config.cell, batch_size=config.batch_size) cell=config.cell, batch_size=config.batch_size)
# load checkpoint # load checkpoint
param_dict = load_checkpoint(args.ckpt_file) param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
net.set_train(False) net.set_train(False)
image = Tensor(np.ones([config.batch_size, 50], np.int32)) image = Tensor(np.ones([config.batch_size, 50], np.int32))
export(net, image, file_name=args.file_name, file_format=args.file_format) export(net, image, file_name=config.file_name, file_format=config.file_format)
if __name__ == "__main__":
run_export()

View File

@ -18,7 +18,7 @@ import argparse
import numpy as np import numpy as np
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from src.config import textrcnn_cfg as cfg from src.model_utils.config import config as cfg
parser = argparse.ArgumentParser(description='postprocess') parser = argparse.ArgumentParser(description='postprocess')
parser.add_argument('--label_path', type=str, default="./preprocess_Result/label_ids.npy") parser.add_argument('--label_path', type=str, default="./preprocess_Result/label_ids.npy")

View File

@ -17,7 +17,7 @@ import os
import argparse import argparse
import numpy as np import numpy as np
from src.config import textrcnn_cfg as cfg from src.model_utils.config import config as cfg
from src.dataset import create_dataset from src.dataset import create_dataset
parser = argparse.ArgumentParser(description='preprocess') parser = argparse.ArgumentParser(description='preprocess')

View File

@ -65,6 +65,8 @@ Dataset used: [Sentence polarity dataset v1.0](http://www.cs.cornell.edu/people/
- Running on Ascend - Running on Ascend
If you are running the scripts for the first time and , you must set the parameter 'preprocess' to 'true' in the `default_config.yaml` and run training to get the folder 'preprocess' containing data。
```python ```python
# run training # run training
DEVICE_ID=7 python train.py DEVICE_ID=7 python train.py
@ -77,6 +79,58 @@ DEVICE_ID=7 python eval.py --ckpt_path {checkpoint path}
bash scripts/run_eval.sh bash scripts/run_eval.sh
``` ```
- Running on ModelArts
If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows
You have to prepare the folder 'preprocess'.
You can change the file name of 'requirements.txt' to 'pip-requirements.txt' for installing some third party libraries automatically on ModelArts.
- Training standalone on ModelArts
```python
# (1) Upload the code folder to S3 bucket.
# (2) Click to "create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/textrcnn" on the website UI interface.
# (4) Set the startup file to /{path}/textrcnn/train.py" on the website UI interface.
# (5) Perform a or b.
# a. setting parameters in /{path}/textrcnn/default_config.yaml.
# 1. Set ”enable_modelarts: True“
# 2. Set ”cell: 'lstm'“(Default is 'gru'. if you want to use lstm, you can do this step)
# b. adding on the website UI interface.
# 1. Set ”enable_modelarts=True“
# 2. Set ”cell=lstm“(Default is 'gru'. if you want to use lstm, you can do this step)
# (6) Upload the dataset(the folder 'preprocess') to S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path.
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Under the item "resource pool selection", select the specification of a single cards.
# (10) Create your job.
```
- evaluating with single card on ModelArts
```python
# (1) Upload the code folder to S3 bucket.
# (2) Click to "create training task" on the website UI interface.
# (3) Set the code directory to "/{path}/textrcnn" on the website UI interface.
# (4) Set the startup file to /{path}/textrcnn/eval.py" on the website UI interface.
# (5) Perform a or b.
# a. setting parameters in /{path}/textrcnn/default_config.yaml.
# 1. Set ”enable_modelarts: True“
# 2. Set ”cell: 'lstm'“(Default is 'gru'. If you want to use lstm, you can do this step)
# 3. Set ”ckpt_path: './{path}/*.ckpt'“(The *.ckpt file must under the folder 'textrcnn')
# b. adding on the website UI interface.
# 1. Set ”enable_modelarts=True“
# 2. Set ”cell=lstm“(Default is 'gru'. if you want to use lstm, you can do this step)
# 3. Set ”ckpt_path=./{path}/*.ckpt“(The *.ckpt file must under the folder 'textrcnn')
# (6) Upload the dataset(the folder 'preprocess') to S3 bucket.
# (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path.
# (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
# (9) Under the item "resource pool selection", select the specification of a single cards.
# (10) Create your job.
```
## [Script Description](#contents) ## [Script Description](#contents)
### [Script and Sample Code](#contents) ### [Script and Sample Code](#contents)
@ -85,63 +139,61 @@ bash scripts/run_eval.sh
├── model_zoo ├── model_zoo
├── README.md // descriptions about all the models ├── README.md // descriptions about all the models
├── textrcnn ├── textrcnn
├── README.md // descriptions about TextRCNN ├── readme.md // descriptions about TextRCNN
├── data_src ├── ascend310_infer // application for 310 inference
│ ├──rt-polaritydata // directory to save the source data
│ ├──rt-polaritydata.README.1.0.txt // readme file of dataset
├── scripts ├── scripts
│ ├──run_train.sh // shell script for train on Ascend │ ├──run_train.sh // shell script for train on Ascend
│ ├──run_eval.sh // shell script for evaluation on Ascend │ ├──run_infer_310.sh // shell script for 310 infer
├──sample.txt // example shell to run the above the two scripts └──run_eval.sh // shell script for evaluation on Ascend
├── src ├── src
│ ├──model_utils
│ │ ├──config.py // parsing parameter configuration file of "*.yaml"
│ │ ├──device_adapter.py // local or ModelArts training
│ │ ├──local_adapter.py // get related environment variables in local training
│ │ └──moxing_adapter.py // get related environment variables in ModelArts training
│ ├──dataset.py // creating dataset │ ├──dataset.py // creating dataset
│ ├──textrcnn.py // textrcnn architecture │ ├──textrcnn.py // textrcnn architecture
│ ├──config.py // parameter configuration │ └──utils.py // function related to learning rate
├── train.py // training script
├── export.py // export script
├── eval.py // evaluation script
├── data_helpers.py // dataset split script ├── data_helpers.py // dataset split script
├── sample.txt // the shell to train and eval the model without scripts ├── default_config.yaml // parameter configuration
├── eval.py // evaluation script
├── export.py // export script
├── mindspore_hub_conf.py // mindspore hub interface
├── postprocess.py // 310infer postprocess script
├── preprocess.py // dataset generation script
├── requirements.txt // some third party libraries that need to be installed
├── sample.txt // the shell to train and eval the model without '*.sh'
└── train.py // training script
``` ```
### [Script Parameters](#contents) ### [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py Parameters for both training and evaluation can be set in `default_config.yaml`
- config for Textrcnn, Sentence polarity dataset v1.0. - config for Textrcnn, Sentence polarity dataset v1.0.
```python ```python
'num_epochs': 10, # total training epochs num_epochs: 10 # total training epochs
'lstm_num_epochs': 15, # total training epochs when using lstm lstm_num_epochs: 15 # total training epochs when using lstm
'batch_size': 64, # training batch size batch_size: 64 # training batch size
'cell': 'gru', # the RNN architecture, can be 'vanilla', 'gru' and 'lstm'. cell: 'gru' # the RNN architecture, can be 'vanilla', 'gru' and 'lstm'.
'ckpt_folder_path': './ckpt', # the path to save the checkpoints ckpt_folder_path: './ckpt' # the path to save the checkpoints
'preprocess_path': './preprocess', # the directory to save the processed data preprocess_path: './preprocess' # the directory to save the processed data
'preprocess' : 'false', # whethere to preprocess the data preprocess: 'false' # whethere to preprocess the data
'data_path': './data/', # the path to store the splited data data_path: './data/' # the path to store the splited data
'lr': 1e-3, # the training learning rate lr: 0.001 # 1e-3 # the training learning rate
'lstm_lr_init': 2e-3, # learning rate initial value when using lstm lstm_lr_init: 0.002 # 2e-3 # learning rate initial value when using lstm
'lstm_lr_end': 5e-4, # learning rate end value when using lstm lstm_lr_end: 0.0005 # 5e-4 # learning rate end value when using lstm
'lstm_lr_max': 3e-3, # learning eate max value when using lstm lstm_lr_max: 0.003 # 3e-3 # learning eate max value when using lstm
'lstm_lr_warm_up_epochs': 2 # warm up epoch num when using lstm lstm_lr_warm_up_epochs: 2 # warm up epoch num when using lstm
'lstm_lr_adjust_epochs': 9 # lr adjust in lr_adjust_epoch, after that, the lr is lr_end when using lstm lstm_lr_adjust_epochs: 9 # lr adjust in lr_adjust_epoch, after that, the lr is lr_end when using lstm
'emb_path': './word2vec', # the directory to save the embedding file emb_path: './word2vec' # the directory to save the embedding file
'embed_size': 300, # the dimension of the word embedding embed_size: 300 # the dimension of the word embedding
'save_checkpoint_steps': 149, # per step to save the checkpoint save_checkpoint_steps: 149 # per step to save the checkpoint
'keep_checkpoint_max': 10 # max checkpoints to save keep_checkpoint_max: 10 # max checkpoints to save
ckpt_path: '' # relative path of '*.ckpt' to be evaluated relative to the eval.py
``` ```
### Performance
| Model | MindSpore + Ascend | TensorFlow+GPU |
| -------------------------- | ----------------------------- | ------------------------- |
| Resource | Ascend 910; OS Euler2.8 | NV SMX2 V100-32G |
| Version | 1.0.1 | 1.4.0 |
| Dataset | Sentence polarity dataset v1.0 | Sentence polarity dataset v1.0 |
| batch_size | 64 | 64 |
| Accuracy | 0.78 | 0.78 |
| Speed | 35ms/step | 77ms/step |
## Inference Process ## Inference Process
### [Export MindIR](#contents) ### [Export MindIR](#contents)
@ -173,6 +225,17 @@ Inference result is saved in current path, you can find result like this in acc.
============== Accuracy:{} ============== 0.8008 ============== Accuracy:{} ============== 0.8008
``` ```
### Performance
| Model | MindSpore + Ascend | TensorFlow+GPU |
| -------------------------- | ----------------------------- | ------------------------- |
| Resource | Ascend 910; OS Euler2.8 | NV SMX2 V100-32G |
| Version | 1.0.1 | 1.4.0 |
| Dataset | Sentence polarity dataset v1.0 | Sentence polarity dataset v1.0 |
| batch_size | 64 | 64 |
| Accuracy | 0.78 | 0.78 |
| Speed | 35ms/step | 77ms/step |
## [ModelZoo Homepage](#contents) ## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -13,6 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
if [ $# != 2 ] && [ $# != 1 ]; then
echo "Usage: sh run_eval.sh [CKPT_PATH] [DEVICE_ID]"
exit 1
fi
if [ $# == 2 ]; then
export DEVICE_ID=$2
else
export DEVICE_ID=0
fi
ulimit -u unlimited ulimit -u unlimited
BASEPATH=$(cd "`dirname $0`" || exit; pwd) BASEPATH=$(cd "`dirname $0`" || exit; pwd)

View File

@ -13,6 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
if [ $# != 1 ] && [ $# != 0 ]; then
echo "Usage: sh run_train.sh [DEVICE_ID]"
exit 1
fi
if [ $# == 1 ]; then
export DEVICE_ID=$1
else
export DEVICE_ID=0
fi
ulimit -u unlimited ulimit -u unlimited

View File

@ -1,42 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config
"""
from easydict import EasyDict as edict
# LSTM CONFIG
textrcnn_cfg = edict({
'pos_dir': 'data/rt-polaritydata/rt-polarity.pos',
'neg_dir': 'data/rt-polaritydata/rt-polarity.neg',
'num_epochs': 10,
'lstm_num_epochs': 15,
'batch_size': 64,
'cell': 'gru',
'ckpt_folder_path': './ckpt',
'preprocess_path': './preprocess',
'preprocess': 'false',
'data_path': './data/',
'lr': 1e-3,
'lstm_lr_init': 2e-3,
'lstm_lr_end': 5e-4,
'lstm_lr_max': 3e-3,
'lstm_lr_warm_up_epochs': 2,
'lstm_lr_adjust_epochs': 9,
'emb_path': './word2vec',
'embed_size': 300,
'save_checkpoint_steps': 149,
'keep_checkpoint_max': 10,
})

View File

@ -0,0 +1,130 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()
if __name__ == '__main__':
print(config)

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from src.model_utils.config import config
if config.enable_modelarts:
from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from src.model_utils.config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
# print("os.mknod({}) success".format(sync_lock))
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
if config.enable_profiling:
profiler = Profiler()
run_func(*args, **kwargs)
if config.enable_profiling:
profiler.analyse()
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -63,6 +63,10 @@ class textrcnn(nn.Cell):
self.rnnU_fw = nn.Dense(self.embed_size, self.num_hiddens) self.rnnU_fw = nn.Dense(self.embed_size, self.num_hiddens)
self.rnnW_bw = nn.Dense(self.num_hiddens, self.num_hiddens) self.rnnW_bw = nn.Dense(self.num_hiddens, self.num_hiddens)
self.rnnU_bw = nn.Dense(self.embed_size, self.num_hiddens) self.rnnU_bw = nn.Dense(self.embed_size, self.num_hiddens)
self.rnnW_fw.to_float(mstype.float16)
self.rnnU_fw.to_float(mstype.float16)
self.rnnW_bw.to_float(mstype.float16)
self.rnnU_bw.to_float(mstype.float16)
if cell == "gru": if cell == "gru":
self.rnnWr_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens) self.rnnWr_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)

View File

@ -25,22 +25,31 @@ from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.common import set_seed from mindspore.common import set_seed
from src.config import textrcnn_cfg as cfg
from src.dataset import create_dataset from src.dataset import create_dataset
from src.dataset import convert_to_mindrecord from src.dataset import convert_to_mindrecord
from src.textrcnn import textrcnn from src.textrcnn import textrcnn
from src.utils import get_lr from src.utils import get_lr
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config as cfg
from src.model_utils.device_adapter import get_device_id
set_seed(0) set_seed(0)
if __name__ == '__main__': def modelarts_pre_process():
'''modelarts pre process function.'''
cfg.ckpt_folder_path = os.path.join(cfg.output_path, cfg.ckpt_folder_path)
cfg.preprocess_path = cfg.data_path
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
'''train function.'''
context.set_context( context.set_context(
mode=context.GRAPH_MODE, mode=context.GRAPH_MODE,
save_graphs=False, save_graphs=False,
device_target="Ascend") device_target="Ascend")
device_id = int(os.getenv('DEVICE_ID')) device_id = get_device_id()
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
if cfg.preprocess == 'true': if cfg.preprocess == 'true':
@ -48,7 +57,7 @@ if __name__ == '__main__':
if os.path.exists(cfg.preprocess_path): if os.path.exists(cfg.preprocess_path):
shutil.rmtree(cfg.preprocess_path) shutil.rmtree(cfg.preprocess_path)
os.mkdir(cfg.preprocess_path) os.mkdir(cfg.preprocess_path)
convert_to_mindrecord(cfg.embed_size, cfg.data_path, cfg.preprocess_path, cfg.emb_path) convert_to_mindrecord(cfg.embed_size, cfg.data_root, cfg.preprocess_path, cfg.emb_path)
if cfg.cell == "vanilla": if cfg.cell == "vanilla":
print("============ Precision is lower than expected when using vanilla RNN architecture ===========") print("============ Precision is lower than expected when using vanilla RNN architecture ===========")
@ -79,3 +88,6 @@ if __name__ == '__main__':
ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck) ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck)
model.train(num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb, time_cb]) model.train(num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb, time_cb])
print("train success") print("train success")
if __name__ == '__main__':
run_train()