From bced1e6d4ea5ed44473dd68d7b0b8b85e88fa989 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=83=91=E5=BD=AC?=
Date: Thu, 10 Jun 2021 15:23:41 +0800
Subject: [PATCH] textrcnn

---
 .../research/nlp/textrcnn/default_config.yaml |  60 +++++++
 model_zoo/research/nlp/textrcnn/eval.py       |  29 ++--
 model_zoo/research/nlp/textrcnn/export.py     |  27 ++--
 .../research/nlp/textrcnn/postprocess.py      |   2 +-
 model_zoo/research/nlp/textrcnn/preprocess.py |   2 +-
 model_zoo/research/nlp/textrcnn/readme.md     | 153 ++++++++++++------
 .../research/nlp/textrcnn/scripts/run_eval.sh |  11 ++
 .../nlp/textrcnn/scripts/run_train.sh         |  10 ++
 model_zoo/research/nlp/textrcnn/src/config.py |  42 -----
 .../nlp/textrcnn/src/model_utils/config.py    | 130 +++++++++++++++
 .../src/model_utils/device_adapter.py         |  27 ++++
 .../textrcnn/src/model_utils/local_adapter.py |  36 +++++
 .../src/model_utils/moxing_adapter.py         | 123 ++++++++++++++
 .../research/nlp/textrcnn/src/textrcnn.py     |   4 +
 model_zoo/research/nlp/textrcnn/train.py      |  20 ++-
 15 files changed, 556 insertions(+), 120 deletions(-)
 create mode 100644 model_zoo/research/nlp/textrcnn/default_config.yaml
 delete mode 100644 model_zoo/research/nlp/textrcnn/src/config.py
 create mode 100644 model_zoo/research/nlp/textrcnn/src/model_utils/config.py
 create mode 100644 model_zoo/research/nlp/textrcnn/src/model_utils/device_adapter.py
 create mode 100644 model_zoo/research/nlp/textrcnn/src/model_utils/local_adapter.py
 create mode 100644 model_zoo/research/nlp/textrcnn/src/model_utils/moxing_adapter.py

diff --git a/model_zoo/research/nlp/textrcnn/default_config.yaml b/model_zoo/research/nlp/textrcnn/default_config.yaml
new file mode 100644
index 00000000000..665cff9a934
--- /dev/null
+++ b/model_zoo/research/nlp/textrcnn/default_config.yaml
@@ -0,0 +1,60 @@
+# Built-in configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "Ascend"
+enable_profiling: False
+
+# ==============================================================================
+pos_dir: 'data/rt-polaritydata/rt-polarity.pos'
+neg_dir: 'data/rt-polaritydata/rt-polarity.neg'
+num_epochs: 10
+lstm_num_epochs: 15
+batch_size: 64
+cell: 'gru'
+ckpt_folder_path: './ckpt'
+preprocess_path: './preprocess'
+preprocess: 'false'
+data_root: './data/'
+lr: 0.001  # 1e-3
+lstm_lr_init: 0.002  # 2e-3
+lstm_lr_end: 0.0005  # 5e-4
+lstm_lr_max: 0.003  # 3e-3
+lstm_lr_warm_up_epochs: 2
+lstm_lr_adjust_epochs: 9
+emb_path: './word2vec'
+embed_size: 300
+save_checkpoint_steps: 149
+keep_checkpoint_max: 10
+ckpt_path: ''
+
+# Export related
+ckpt_file: ''
+file_name: 'textrcnn'
+file_format: "MINDIR"
+
+---
+
+# Help description for each configuration
+# ModelArts related
+enable_modelarts: "Whether to train on ModelArts, default: False"
+data_url: "Url for modelarts"
+train_url: "Url for modelarts"
+data_path: "The location of the input data."
+output_path: "The location of the output file."
+device_target: "Running platform, choose from Ascend, GPU or CPU; default is Ascend."
+enable_profiling: 'Whether to enable profiling while training, default: False'
+# Export related
+ckpt_file: "textrcnn ckpt file."
+file_name: "textrcnn output file name."
+file_format: "file format, choose from MINDIR or AIR" + +--- +file_format: ["AIR", "MINDIR"] +device_target: ["Ascend"] \ No newline at end of file diff --git a/model_zoo/research/nlp/textrcnn/eval.py b/model_zoo/research/nlp/textrcnn/eval.py index cb36f70891a..d6b980fe524 100644 --- a/model_zoo/research/nlp/textrcnn/eval.py +++ b/model_zoo/research/nlp/textrcnn/eval.py @@ -14,7 +14,6 @@ # ============================================================================ """model evaluation script""" import os -import argparse import numpy as np import mindspore.nn as nn @@ -23,25 +22,32 @@ from mindspore import Tensor from mindspore.train import Model from mindspore.nn.metrics import Accuracy from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.callback import LossMonitor +# from mindspore.train.callback import LossMonitor from mindspore.common import set_seed -from src.config import textrcnn_cfg as cfg from src.dataset import create_dataset from src.textrcnn import textrcnn +from src.model_utils.moxing_adapter import moxing_wrapper +from src.model_utils.config import config as cfg +from src.model_utils.device_adapter import get_device_id set_seed(1) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='textrcnn') - parser.add_argument('--ckpt_path', type=str) - args = parser.parse_args() +def modelarts_pre_process(): + '''modelarts pre process function.''' + cfg.ckpt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg.ckpt_path) + cfg.preprocess_path = cfg.data_path + + +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_eval(): + '''eval function.''' context.set_context( mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend") - device_id = int(os.getenv('DEVICE_ID')) + device_id = get_device_id() context.set_context(device_id=device_id) embedding_table = np.loadtxt(os.path.join(cfg.preprocess_path, "weight.txt")).astype(np.float32) @@ -49,12 +55,15 @@ if __name__ == '__main__': cell=cfg.cell, batch_size=cfg.batch_size) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) eval_net = nn.WithEvalCell(network, loss, True) - loss_cb = LossMonitor() + # loss_cb = LossMonitor() print("============== Starting Testing ==============") ds_eval = create_dataset(cfg.preprocess_path, cfg.batch_size, False) - param_dict = load_checkpoint(args.ckpt_path) + param_dict = load_checkpoint(cfg.ckpt_path) load_param_into_net(network, param_dict) network.set_train(False) model = Model(network, loss, metrics={'acc': Accuracy()}, eval_network=eval_net, eval_indexes=[0, 1, 2]) acc = model.eval(ds_eval, dataset_sink_mode=False) print("============== Accuracy:{} ==============".format(acc)) + +if __name__ == '__main__': + run_eval() diff --git a/model_zoo/research/nlp/textrcnn/export.py b/model_zoo/research/nlp/textrcnn/export.py index 36e52fad9c5..2f96a179327 100644 --- a/model_zoo/research/nlp/textrcnn/export.py +++ b/model_zoo/research/nlp/textrcnn/export.py @@ -14,26 +14,16 @@ # ============================================================================ """textrcnn export ckpt file to mindir/air""" import os -import argparse import numpy as np from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export from src.textrcnn import textrcnn -from src.config import textrcnn_cfg as config +from src.model_utils.config import config +from src.model_utils.device_adapter import get_device_id -parser = argparse.ArgumentParser(description="textrcnn") -parser.add_argument("--device_id", type=int, 
-                    default=0, help="Device id")
-parser.add_argument("--ckpt_file", type=str, required=True, help="textrcnn ckpt file.")
-parser.add_argument("--file_name", type=str, default="textrcnn", help="textrcnn output file name.")
-parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"],
-                    default="MINDIR", help="file format")
-parser.add_argument("--device_target", type=str, choices=["Ascend"], default="Ascend",
-                    help="device target")
-args = parser.parse_args()
-
-context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
-
-if __name__ == "__main__":
+def run_export():
+    '''export function.'''
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
     # define net
     embedding_table = np.loadtxt(os.path.join(config.preprocess_path, "weight.txt")).astype(np.float32)
@@ -41,9 +31,12 @@ if __name__ == "__main__":
                    cell=config.cell, batch_size=config.batch_size)
     # load checkpoint
-    param_dict = load_checkpoint(args.ckpt_file)
+    param_dict = load_checkpoint(config.ckpt_file)
     load_param_into_net(net, param_dict)
     net.set_train(False)
     image = Tensor(np.ones([config.batch_size, 50], np.int32))
-    export(net, image, file_name=args.file_name, file_format=args.file_format)
+    export(net, image, file_name=config.file_name, file_format=config.file_format)
+
+if __name__ == "__main__":
+    run_export()
diff --git a/model_zoo/research/nlp/textrcnn/postprocess.py b/model_zoo/research/nlp/textrcnn/postprocess.py
index cddd58acab6..6098a7397f1 100644
--- a/model_zoo/research/nlp/textrcnn/postprocess.py
+++ b/model_zoo/research/nlp/textrcnn/postprocess.py
@@ -18,7 +18,7 @@ import argparse
 import numpy as np
 from mindspore.nn.metrics import Accuracy
 
-from src.config import textrcnn_cfg as cfg
+from src.model_utils.config import config as cfg
 
 parser = argparse.ArgumentParser(description='postprocess')
 parser.add_argument('--label_path', type=str, default="./preprocess_Result/label_ids.npy")
diff --git a/model_zoo/research/nlp/textrcnn/preprocess.py b/model_zoo/research/nlp/textrcnn/preprocess.py
index 9ac666b9899..3cac94cd412 100644
--- a/model_zoo/research/nlp/textrcnn/preprocess.py
+++ b/model_zoo/research/nlp/textrcnn/preprocess.py
@@ -17,7 +17,7 @@ import os
 import argparse
 import numpy as np
 
-from src.config import textrcnn_cfg as cfg
+from src.model_utils.config import config as cfg
 from src.dataset import create_dataset
 
 parser = argparse.ArgumentParser(description='preprocess')
diff --git a/model_zoo/research/nlp/textrcnn/readme.md b/model_zoo/research/nlp/textrcnn/readme.md
index 2ee5b60e5ae..2734a1ae2cd 100644
--- a/model_zoo/research/nlp/textrcnn/readme.md
+++ b/model_zoo/research/nlp/textrcnn/readme.md
@@ -65,6 +65,8 @@ Dataset used: [Sentence polarity dataset v1.0](http://www.cs.cornell.edu/people/
 
 - Running on Ascend
 
+  If you are running the scripts for the first time, you must set the parameter 'preprocess' to 'true' in `default_config.yaml` and run training once to generate the folder 'preprocess' that contains the processed data.
+
 ```python
 # run training
 DEVICE_ID=7 python train.py
@@ -77,71 +79,121 @@ DEVICE_ID=7 python eval.py --ckpt_path {checkpoint path}
 bash scripts/run_eval.sh
 ```
 
+- Running on ModelArts
+
+  If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/); then you can start training as follows.
+
+  You have to prepare the folder 'preprocess' in advance, for example with the command sketched below.
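+
+  A minimal local sketch (the 'preprocess' and 'preprocess_path' keys come from `default_config.yaml`; run this once on a local machine, then upload the resulting folder):
+
+  ```python
+  # generate the folder 'preprocess' locally, then upload it to the S3 bucket
+  python train.py --preprocess=true --preprocess_path=./preprocess
+  ```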
+
+  You can rename 'requirements.txt' to 'pip-requirements.txt' so that the required third-party libraries are installed automatically on ModelArts.
+
+  - Training standalone on ModelArts
+
+    ```python
+    # (1) Upload the code folder to S3 bucket.
+    # (2) Click "create training task" on the website UI interface.
+    # (3) Set the code directory to "/{path}/textrcnn" on the website UI interface.
+    # (4) Set the startup file to "/{path}/textrcnn/train.py" on the website UI interface.
+    # (5) Perform a or b.
+    #     a. Set parameters in /{path}/textrcnn/default_config.yaml.
+    #         1. Set "enable_modelarts: True"
+    #         2. Set "cell: 'lstm'" (default is 'gru'; do this step only if you want to use lstm)
+    #     b. Add parameters on the website UI interface.
+    #         1. Set "enable_modelarts=True"
+    #         2. Set "cell=lstm" (default is 'gru'; do this step only if you want to use lstm)
+    # (6) Upload the dataset (the folder 'preprocess') to the S3 bucket.
+    # (7) Check the "data storage location" on the website UI interface and set the "Dataset path".
+    # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
+    # (9) Under the item "resource pool selection", select the specification of a single card.
+    # (10) Create your job.
+    ```
+
+  - Evaluating with a single card on ModelArts
+
+    ```python
+    # (1) Upload the code folder to S3 bucket.
+    # (2) Click "create training task" on the website UI interface.
+    # (3) Set the code directory to "/{path}/textrcnn" on the website UI interface.
+    # (4) Set the startup file to "/{path}/textrcnn/eval.py" on the website UI interface.
+    # (5) Perform a or b.
+    #     a. Set parameters in /{path}/textrcnn/default_config.yaml.
+    #         1. Set "enable_modelarts: True"
+    #         2. Set "cell: 'lstm'" (default is 'gru'; do this step only if you want to use lstm)
+    #         3. Set "ckpt_path: './{path}/*.ckpt'" (the *.ckpt file must be under the folder 'textrcnn')
+    #     b. Add parameters on the website UI interface.
+    #         1. Set "enable_modelarts=True"
+    #         2. Set "cell=lstm" (default is 'gru'; do this step only if you want to use lstm)
+    #         3. Set "ckpt_path=./{path}/*.ckpt" (the *.ckpt file must be under the folder 'textrcnn')
+    # (6) Upload the dataset (the folder 'preprocess') to the S3 bucket.
+    # (7) Check the "data storage location" on the website UI interface and set the "Dataset path".
+    # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
+    # (9) Under the item "resource pool selection", select the specification of a single card.
+    # (10) Create your job.
+ ``` + ## [Script Description](#contents) ### [Script and Sample Code](#contents) ```python ├── model_zoo - ├── README.md // descriptions about all the models + ├── README.md // descriptions about all the models ├── textrcnn - ├── README.md // descriptions about TextRCNN - ├── data_src - │ ├──rt-polaritydata // directory to save the source data - │ ├──rt-polaritydata.README.1.0.txt // readme file of dataset + ├── readme.md // descriptions about TextRCNN + ├── ascend310_infer // application for 310 inference ├── scripts │ ├──run_train.sh // shell script for train on Ascend - │ ├──run_eval.sh // shell script for evaluation on Ascend - │ ├──sample.txt // example shell to run the above the two scripts + │ ├──run_infer_310.sh // shell script for 310 infer + │ └──run_eval.sh // shell script for evaluation on Ascend ├── src - │ ├──dataset.py // creating dataset - │ ├──textrcnn.py // textrcnn architecture - │ ├──config.py // parameter configuration - ├── train.py // training script - ├── export.py // export script - ├── eval.py // evaluation script - ├── data_helpers.py // dataset split script - ├── sample.txt // the shell to train and eval the model without scripts + │ ├──model_utils + │ │ ├──config.py // parsing parameter configuration file of "*.yaml" + │ │ ├──device_adapter.py // local or ModelArts training + │ │ ├──local_adapter.py // get related environment variables in local training + │ │ └──moxing_adapter.py // get related environment variables in ModelArts training + │ ├──dataset.py // creating dataset + │ ├──textrcnn.py // textrcnn architecture + │ └──utils.py // function related to learning rate + ├── data_helpers.py // dataset split script + ├── default_config.yaml // parameter configuration + ├── eval.py // evaluation script + ├── export.py // export script + ├── mindspore_hub_conf.py // mindspore hub interface + ├── postprocess.py // 310infer postprocess script + ├── preprocess.py // dataset generation script + ├── requirements.txt // some third party libraries that need to be installed + ├── sample.txt // the shell to train and eval the model without '*.sh' + └── train.py // training script ``` ### [Script Parameters](#contents) -Parameters for both training and evaluation can be set in config.py +Parameters for both training and evaluation can be set in `default_config.yaml` - config for Textrcnn, Sentence polarity dataset v1.0. ```python - 'num_epochs': 10, # total training epochs - 'lstm_num_epochs': 15, # total training epochs when using lstm - 'batch_size': 64, # training batch size - 'cell': 'gru', # the RNN architecture, can be 'vanilla', 'gru' and 'lstm'. 
-    'ckpt_folder_path': './ckpt', # the path to save the checkpoints
-    'preprocess_path': './preprocess', # the directory to save the processed data
-    'preprocess' : 'false', # whethere to preprocess the data
-    'data_path': './data/', # the path to store the splited data
-    'lr': 1e-3, # the training learning rate
-    'lstm_lr_init': 2e-3, # learning rate initial value when using lstm
-    'lstm_lr_end': 5e-4, # learning rate end value when using lstm
-    'lstm_lr_max': 3e-3, # learning eate max value when using lstm
-    'lstm_lr_warm_up_epochs': 2 # warm up epoch num when using lstm
-    'lstm_lr_adjust_epochs': 9 # lr adjust in lr_adjust_epoch, after that, the lr is lr_end when using lstm
-    'emb_path': './word2vec', # the directory to save the embedding file
-    'embed_size': 300, # the dimension of the word embedding
-    'save_checkpoint_steps': 149, # per step to save the checkpoint
-    'keep_checkpoint_max': 10 # max checkpoints to save
+  num_epochs: 10                   # total training epochs
+  lstm_num_epochs: 15              # total training epochs when using lstm
+  batch_size: 64                   # training batch size
+  cell: 'gru'                      # the RNN architecture, can be 'vanilla', 'gru' and 'lstm'.
+  ckpt_folder_path: './ckpt'       # the path to save the checkpoints
+  preprocess_path: './preprocess'  # the directory to save the processed data
+  preprocess: 'false'              # whether to preprocess the data
+  data_root: './data/'             # the path to store the split data
+  lr: 0.001                        # 1e-3, the training learning rate
+  lstm_lr_init: 0.002              # 2e-3, initial learning rate when using lstm
+  lstm_lr_end: 0.0005              # 5e-4, end learning rate when using lstm
+  lstm_lr_max: 0.003               # 3e-3, max learning rate when using lstm
+  lstm_lr_warm_up_epochs: 2        # warm up epoch num when using lstm
+  lstm_lr_adjust_epochs: 9         # lr is adjusted within lr_adjust_epochs; after that, the lr stays at lr_end when using lstm
+  emb_path: './word2vec'           # the directory to save the embedding file
+  embed_size: 300                  # the dimension of the word embedding
+  save_checkpoint_steps: 149       # interval of steps for saving the checkpoint
+  keep_checkpoint_max: 10          # max checkpoints to save
+  ckpt_path: ''                    # path of the '*.ckpt' file to evaluate, relative to eval.py
 ```
 
-### Performance
-
-| Model                      | MindSpore + Ascend            | TensorFlow+GPU            |
-| -------------------------- | ----------------------------- | ------------------------- |
-| Resource                   | Ascend 910; OS Euler2.8       | NV SMX2 V100-32G          |
-| Version                    | 1.0.1                         | 1.4.0                     |
-| Dataset                    | Sentence polarity dataset v1.0 | Sentence polarity dataset v1.0 |
-| batch_size                 | 64                            | 64                        |
-| Accuracy                   | 0.78                          | 0.78                      |
-| Speed                      | 35ms/step                     | 77ms/step                 |
-
 ## Inference Process
 
 ### [Export MindIR](#contents)
@@ -173,6 +225,17 @@ Inference result is saved in current path, you can find result like this in acc.
 ============== Accuracy:{} ============== 0.8008
 ```
 
+### Performance
+
+| Model                      | MindSpore + Ascend            | TensorFlow+GPU            |
+| -------------------------- | ----------------------------- | ------------------------- |
+| Resource                   | Ascend 910; OS Euler2.8       | NV SMX2 V100-32G          |
+| Version                    | 1.0.1                         | 1.4.0                     |
+| Dataset                    | Sentence polarity dataset v1.0 | Sentence polarity dataset v1.0 |
+| batch_size                 | 64                            | 64                        |
+| Accuracy                   | 0.78                          | 0.78                      |
+| Speed                      | 35ms/step                     | 77ms/step                 |
+
 ## [ModelZoo Homepage](#contents)
 
 Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
diff --git a/model_zoo/research/nlp/textrcnn/scripts/run_eval.sh b/model_zoo/research/nlp/textrcnn/scripts/run_eval.sh index 519f4273472..1bc9e51eaa9 100644 --- a/model_zoo/research/nlp/textrcnn/scripts/run_eval.sh +++ b/model_zoo/research/nlp/textrcnn/scripts/run_eval.sh @@ -13,6 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +if [ $# != 2 ] && [ $# != 1 ]; then + echo "Usage: sh run_eval.sh [CKPT_PATH] [DEVICE_ID]" + exit 1 +fi + +if [ $# == 2 ]; then + export DEVICE_ID=$2 +else + export DEVICE_ID=0 +fi + ulimit -u unlimited BASEPATH=$(cd "`dirname $0`" || exit; pwd) diff --git a/model_zoo/research/nlp/textrcnn/scripts/run_train.sh b/model_zoo/research/nlp/textrcnn/scripts/run_train.sh index 5e87829a5bf..dc7bf436734 100644 --- a/model_zoo/research/nlp/textrcnn/scripts/run_train.sh +++ b/model_zoo/research/nlp/textrcnn/scripts/run_train.sh @@ -13,6 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +if [ $# != 1 ] && [ $# != 0 ]; then + echo "Usage: sh run_train.sh [DEVICE_ID]" + exit 1 +fi + +if [ $# == 1 ]; then + export DEVICE_ID=$1 +else + export DEVICE_ID=0 +fi ulimit -u unlimited diff --git a/model_zoo/research/nlp/textrcnn/src/config.py b/model_zoo/research/nlp/textrcnn/src/config.py deleted file mode 100644 index 5f105bdbb80..00000000000 --- a/model_zoo/research/nlp/textrcnn/src/config.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -network config -""" -from easydict import EasyDict as edict - -# LSTM CONFIG -textrcnn_cfg = edict({ - 'pos_dir': 'data/rt-polaritydata/rt-polarity.pos', - 'neg_dir': 'data/rt-polaritydata/rt-polarity.neg', - 'num_epochs': 10, - 'lstm_num_epochs': 15, - 'batch_size': 64, - 'cell': 'gru', - 'ckpt_folder_path': './ckpt', - 'preprocess_path': './preprocess', - 'preprocess': 'false', - 'data_path': './data/', - 'lr': 1e-3, - 'lstm_lr_init': 2e-3, - 'lstm_lr_end': 5e-4, - 'lstm_lr_max': 3e-3, - 'lstm_lr_warm_up_epochs': 2, - 'lstm_lr_adjust_epochs': 9, - 'emb_path': './word2vec', - 'embed_size': 300, - 'save_checkpoint_steps': 149, - 'keep_checkpoint_max': 10, -}) diff --git a/model_zoo/research/nlp/textrcnn/src/model_utils/config.py b/model_zoo/research/nlp/textrcnn/src/model_utils/config.py new file mode 100644 index 00000000000..f5120bb8299 --- /dev/null +++ b/model_zoo/research/nlp/textrcnn/src/model_utils/config.py @@ -0,0 +1,130 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Parse arguments"""
+
+import os
+import ast
+import argparse
+from pprint import pprint, pformat
+import yaml

+class Config:
+    """
+    Configuration namespace. Convert dictionary to members.
+    """
+    def __init__(self, cfg_dict):
+        for k, v in cfg_dict.items():
+            if isinstance(v, (list, tuple)):
+                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
+            else:
+                setattr(self, k, Config(v) if isinstance(v, dict) else v)
+
+    def __str__(self):
+        return pformat(self.__dict__)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
+    """
+    Parse command line arguments to the configuration according to the default yaml.
+
+    Args:
+        parser: Parent parser.
+        cfg: Base configuration.
+        helper: Helper description.
+        choices: Valid choices for each argument.
+        cfg_path: Path to the default yaml config.
+    """
+    parser = argparse.ArgumentParser(description="textrcnn config",
+                                     parents=[parser])
+    helper = {} if helper is None else helper
+    choices = {} if choices is None else choices
+    for item in cfg:
+        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
+            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
+            choice = choices[item] if item in choices else None
+            if isinstance(cfg[item], bool):
+                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
+                                    help=help_description)
+            else:
+                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
+                                    help=help_description)
+    args = parser.parse_args()
+    return args
+
+
+def parse_yaml(yaml_path):
+    """
+    Parse the yaml config file.
+
+    Args:
+        yaml_path: Path to the yaml config.
+    """
+    with open(yaml_path, 'r') as fin:
+        try:
+            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
+            cfgs = list(cfgs)
+            if len(cfgs) == 1:
+                cfg_helper = {}
+                cfg = cfgs[0]
+                cfg_choices = {}
+            elif len(cfgs) == 2:
+                cfg, cfg_helper = cfgs
+                cfg_choices = {}
+            elif len(cfgs) == 3:
+                cfg, cfg_helper, cfg_choices = cfgs
+            else:
+                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
+            print(cfg_helper)
+        except yaml.YAMLError as err:
+            raise ValueError("Failed to parse the yaml config file: {}".format(err))
+    return cfg, cfg_helper, cfg_choices
+
+
+def merge(args, cfg):
+    """
+    Merge the base config from yaml file and command line arguments.
+
+    Args:
+        args: Command line arguments.
+        cfg: Base configuration.
+    """
+    args_var = vars(args)
+    for item in args_var:
+        cfg[item] = args_var[item]
+    return cfg
+
+
+def get_config():
+    """
+    Get Config according to the yaml file and cli arguments.
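+
+    Every scalar key in the yaml is also exposed as a command line flag,
+    so a hypothetical override looks like:
+        python train.py --config_path=./default_config.yaml --cell=lstm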
+ """ + parser = argparse.ArgumentParser(description="default name", add_help=False) + current_dir = os.path.dirname(os.path.abspath(__file__)) + parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"), + help="Config file path") + path_args, _ = parser.parse_known_args() + default, helper, choices = parse_yaml(path_args.config_path) + pprint(default) + args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path) + final_config = merge(args, default) + return Config(final_config) + +config = get_config() + +if __name__ == '__main__': + print(config) diff --git a/model_zoo/research/nlp/textrcnn/src/model_utils/device_adapter.py b/model_zoo/research/nlp/textrcnn/src/model_utils/device_adapter.py new file mode 100644 index 00000000000..9c3d21d5e47 --- /dev/null +++ b/model_zoo/research/nlp/textrcnn/src/model_utils/device_adapter.py @@ -0,0 +1,27 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Device adapter for ModelArts""" + +from src.model_utils.config import config + +if config.enable_modelarts: + from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id +else: + from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id + +__all__ = [ + "get_device_id", "get_device_num", "get_rank_id", "get_job_id" +] diff --git a/model_zoo/research/nlp/textrcnn/src/model_utils/local_adapter.py b/model_zoo/research/nlp/textrcnn/src/model_utils/local_adapter.py new file mode 100644 index 00000000000..769fa6dc78e --- /dev/null +++ b/model_zoo/research/nlp/textrcnn/src/model_utils/local_adapter.py @@ -0,0 +1,36 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+"""Local adapter"""
+
+import os
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    return "Local Job"
diff --git a/model_zoo/research/nlp/textrcnn/src/model_utils/moxing_adapter.py b/model_zoo/research/nlp/textrcnn/src/model_utils/moxing_adapter.py
new file mode 100644
index 00000000000..09cb0f0cf0f
--- /dev/null
+++ b/model_zoo/research/nlp/textrcnn/src/model_utils/moxing_adapter.py
@@ -0,0 +1,123 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Moxing adapter for ModelArts"""
+
+import os
+import functools
+from mindspore import context
+from mindspore.profiler import Profiler
+from src.model_utils.config import config
+
+_global_sync_count = 0
+
+def get_device_id():
+    device_id = os.getenv('DEVICE_ID', '0')
+    return int(device_id)
+
+
+def get_device_num():
+    device_num = os.getenv('RANK_SIZE', '1')
+    return int(device_num)
+
+
+def get_rank_id():
+    global_rank_id = os.getenv('RANK_ID', '0')
+    return int(global_rank_id)
+
+
+def get_job_id():
+    job_id = os.getenv('JOB_ID', "")
+    job_id = job_id if job_id != "" else "default"
+    return job_id
+
+def sync_data(from_path, to_path):
+    """
+    Download data from a remote obs url to a local directory if the first argument is a remote url
+    and the second one is a local path; otherwise, upload data from the local directory to remote obs.
+    """
+    import moxing as mox
+    import time
+    global _global_sync_count
+    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
+    _global_sync_count += 1
+
+    # Each server contains at most 8 devices.
+    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+        print("from path: ", from_path)
+        print("to path: ", to_path)
+        mox.file.copy_parallel(from_path, to_path)
+        print("===finish data synchronization===")
+        try:
+            os.mknod(sync_lock)
+            # print("os.mknod({}) success".format(sync_lock))
+        except IOError:
+            pass
+        print("===save flag===")
+
+    while True:
+        if os.path.exists(sync_lock):
+            break
+        time.sleep(1)
+
+    print("Finish sync data from {} to {}.".format(from_path, to_path))
+
+
+def moxing_wrapper(pre_process=None, post_process=None):
+    """
+    Moxing wrapper to download dataset and upload outputs.
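+
+    Usage sketch (this mirrors how train.py in this patch applies it):
+        @moxing_wrapper(pre_process=modelarts_pre_process)
+        def run_train():
+            ...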
+ """ + def wrapper(run_func): + @functools.wraps(run_func) + def wrapped_func(*args, **kwargs): + # Download data from data_url + if config.enable_modelarts: + if config.data_url: + sync_data(config.data_url, config.data_path) + print("Dataset downloaded: ", os.listdir(config.data_path)) + if config.checkpoint_url: + sync_data(config.checkpoint_url, config.load_path) + print("Preload downloaded: ", os.listdir(config.load_path)) + if config.train_url: + sync_data(config.train_url, config.output_path) + print("Workspace downloaded: ", os.listdir(config.output_path)) + + context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id()))) + config.device_num = get_device_num() + config.device_id = get_device_id() + if not os.path.exists(config.output_path): + os.makedirs(config.output_path) + + if pre_process: + pre_process() + + if config.enable_profiling: + profiler = Profiler() + + run_func(*args, **kwargs) + + if config.enable_profiling: + profiler.analyse() + + # Upload data to train_url + if config.enable_modelarts: + if post_process: + post_process() + + if config.train_url: + print("Start to copy output directory") + sync_data(config.output_path, config.train_url) + return wrapped_func + return wrapper diff --git a/model_zoo/research/nlp/textrcnn/src/textrcnn.py b/model_zoo/research/nlp/textrcnn/src/textrcnn.py index fdac3a4936b..3041fe45af9 100644 --- a/model_zoo/research/nlp/textrcnn/src/textrcnn.py +++ b/model_zoo/research/nlp/textrcnn/src/textrcnn.py @@ -63,6 +63,10 @@ class textrcnn(nn.Cell): self.rnnU_fw = nn.Dense(self.embed_size, self.num_hiddens) self.rnnW_bw = nn.Dense(self.num_hiddens, self.num_hiddens) self.rnnU_bw = nn.Dense(self.embed_size, self.num_hiddens) + self.rnnW_fw.to_float(mstype.float16) + self.rnnU_fw.to_float(mstype.float16) + self.rnnW_bw.to_float(mstype.float16) + self.rnnU_bw.to_float(mstype.float16) if cell == "gru": self.rnnWr_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens) diff --git a/model_zoo/research/nlp/textrcnn/train.py b/model_zoo/research/nlp/textrcnn/train.py index 4f23836d1d6..8b4626982c7 100644 --- a/model_zoo/research/nlp/textrcnn/train.py +++ b/model_zoo/research/nlp/textrcnn/train.py @@ -25,22 +25,31 @@ from mindspore.nn.metrics import Accuracy from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.common import set_seed -from src.config import textrcnn_cfg as cfg from src.dataset import create_dataset from src.dataset import convert_to_mindrecord from src.textrcnn import textrcnn from src.utils import get_lr +from src.model_utils.moxing_adapter import moxing_wrapper +from src.model_utils.config import config as cfg +from src.model_utils.device_adapter import get_device_id set_seed(0) -if __name__ == '__main__': +def modelarts_pre_process(): + '''modelarts pre process function.''' + cfg.ckpt_folder_path = os.path.join(cfg.output_path, cfg.ckpt_folder_path) + cfg.preprocess_path = cfg.data_path + +@moxing_wrapper(pre_process=modelarts_pre_process) +def run_train(): + '''train function.''' context.set_context( mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend") - device_id = int(os.getenv('DEVICE_ID')) + device_id = get_device_id() context.set_context(device_id=device_id) if cfg.preprocess == 'true': @@ -48,7 +57,7 @@ if __name__ == '__main__': if os.path.exists(cfg.preprocess_path): shutil.rmtree(cfg.preprocess_path) os.mkdir(cfg.preprocess_path) - convert_to_mindrecord(cfg.embed_size, cfg.data_path, cfg.preprocess_path, 
cfg.emb_path) + convert_to_mindrecord(cfg.embed_size, cfg.data_root, cfg.preprocess_path, cfg.emb_path) if cfg.cell == "vanilla": print("============ Precision is lower than expected when using vanilla RNN architecture ===========") @@ -79,3 +88,6 @@ if __name__ == '__main__': ckpoint_cb = ModelCheckpoint(prefix=cfg.cell, directory=cfg.ckpt_folder_path, config=config_ck) model.train(num_epochs, ds_train, callbacks=[ckpoint_cb, loss_cb, time_cb]) print("train success") + +if __name__ == '__main__': + run_train()