!18043 modify fasttext network for clould

Merge pull request !18043 from zhanghuiyao/fasttext_clould
This commit is contained in:
i-robot 2021-06-09 16:51:36 +08:00 committed by Gitee
commit 0fce707dd9
19 changed files with 748 additions and 246 deletions

View File

@ -63,17 +63,66 @@ architecture. In the following sections, we will introduce how to run the script
After dataset preparation, you can start training and evaluation as follows: After dataset preparation, you can start training and evaluation as follows:
```bash - Running on Ascend
# run training example
cd ./scripts
sh run_standalone_train.sh [TRAIN_DATASET] [DEVICEID]
# run distributed training example ```bash
sh run_distribute_train.sh [TRAIN_DATASET] [RANK_TABLE_PATH] # run training example
cd ./scripts
sh run_standalone_train.sh [TRAIN_DATASET] [DEVICEID]
# run evaluation example # run distributed training example
sh run_eval.sh [EVAL_DATASET_PATH] [DATASET_NAME] [MODEL_CKPT] [DEVICEID] sh run_distribute_train.sh [TRAIN_DATASET] [RANK_TABLE_PATH]
```
# run evaluation example
sh run_eval.sh [EVAL_DATASET_PATH] [DATASET_NAME] [MODEL_CKPT] [DEVICEID]
```
- ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows)
```python
# run standalone training example
# (1) Add "config_path='/path_to_code/[DATASET_NAME]_config.yaml'" on the website UI interface.
# (2) Perform a or b.
# a. Set "enable_modelarts=True" on [DATASET_NAME]_config.yaml file.
# Set "dataset_path='/cache/data/[DATASET_NAME]'" on [DATASET_NAME]_config.yaml file.
# Set "data_name='[DATASET_NAME]'" on [DATASET_NAME]_config.yaml file.
# (option)Set "device_target='GPU'" on [DATASET_NAME]_config.yaml file if run with GPU.
# (option)Set other parameters on [DATASET_NAME]_config.yaml file you need.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add "dataset_path='/cache/data/[DATASET_NAME]'" on the website UI interface.
# Add "data_name='[DATASET_NAME]'" on the website UI interface.
# (option)Set "device_target='GPU'" on the website UI interface if run with GPU.
# (option)Set other parameters on the website UI interface.
# (3) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
# (4) Set the code directory to "/path/fasttext" on the website UI interface.
# (5) Set the startup file to "train.py" on the website UI interface.
# (6) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (7) Create your job.
#
# run evaluation example
# (1) Add "config_path='/path_to_code/[DATASET_NAME]_config.yaml'" on the website UI interface.
# (2) Perform a or b.
# a. Set "enable_modelarts=True" on [DATASET_NAME]_config.yaml file.
# Set "dataset_path='/cache/data/[DATASET_NAME]'" on [DATASET_NAME]_config.yaml file.
# Set "data_name='[DATASET_NAME]'" on [DATASET_NAME]_config.yaml file.
# Set "checkpoint_url='s3://dir_to_trained_ckpt/'" on [DATASET_NAME]_config.yaml file.
# Set "model_ckpt='/cache/checkpoint_path/model.ckpt'" on [DATASET_NAME]_config.yaml file.
# (option)Set "device_target='GPU'" on [DATASET_NAME]_config.yaml file if run with GPU.
# (option)Set other parameters on [DATASET_NAME]_config.yaml file you need.
# b. Add "enable_modelarts=True" on the website UI interface.
# Add "dataset_path='/cache/data/[DATASET_NAME]'" on the website UI interface.
# Add "data_name='[DATASET_NAME]'" on the website UI interface.
# Add "checkpoint_url='s3://dir_to_trained_ckpt/'" on the website UI interface.
# Add "model_ckpt='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
# (option)Set "device_target='GPU'" on the website UI interface if run with GPU.
# (option)Set other parameters on the website UI interface.
# (3) Upload or copy your pretrained model to S3 bucket.
# (4) Upload a zip dataset to S3 bucket. (you could also upload the origin dataset, but it can be so slow.)
# (5) Set the code directory to "/path/fasttext" on the website UI interface.
# (6) Set the startup file to "train.py" on the website UI interface.
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
# (8) Create your job.
```
## [Script Description](#content) ## [Script Description](#content)
@ -82,8 +131,13 @@ The FastText network script and code result are as follows:
```text ```text
├── fasttext ├── fasttext
├── README.md // Introduction of FastText model. ├── README.md // Introduction of FastText model.
├── model_utils
│ ├──__init__.py // module init file
│ ├──config.py // Parse arguments
│ ├──device_adapter.py // Device adapter for ModelArts
│ ├──local_adapter.py // Local adapter
│ ├──moxing_adapter.py // Moxing adapter for ModelArts
├── src ├── src
│ ├──config.py // Configuration instance definition.
│ ├──create_dataset.py // Dataset preparation. │ ├──create_dataset.py // Dataset preparation.
│ ├──fasttext_model.py // FastText model architecture. │ ├──fasttext_model.py // FastText model architecture.
│ ├──fasttext_train.py // Use FastText model architecture. │ ├──fasttext_train.py // Use FastText model architecture.
@ -96,6 +150,11 @@ The FastText network script and code result are as follows:
│ ├──run_distributed_train_gpu.sh // shell script for distributed train on GPU. │ ├──run_distributed_train_gpu.sh // shell script for distributed train on GPU.
│ ├──run_eval_gpu.sh // shell script for standalone eval on GPU. │ ├──run_eval_gpu.sh // shell script for standalone eval on GPU.
│ ├──run_standalone_train_gpu.sh // shell script for standalone train on GPU. │ ├──run_standalone_train_gpu.sh // shell script for standalone train on GPU.
├── ag_config.yaml // ag dataset arguments
├── dbpedia_config.yaml // dbpedia dataset arguments
├── yelpp_config.yaml // yelpp dataset arguments
├── mindspore_hub_conf.py // mindspore hub scripts
├── export.py // Export API entry.
├── eval.py // Infer API entry. ├── eval.py // Infer API entry.
├── requirements.txt // Requirements of third party package. ├── requirements.txt // Requirements of third party package.
├── train.py // Train API entry. ├── train.py // Train API entry.

View File

@ -0,0 +1,57 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: False
modelarts_dataset_unzip_name: ""
# ==============================================================================
# options
vocab_size: 1383812
buckets: [64, 128, 467]
test_buckets: [467]
batch_size: 512
embedding_dims: 16
num_class: 4
epoch: 5
lr: 0.2
min_lr: 0.000001 # 1e-6
decay_steps: 115
warmup_steps: 400000
poly_lr_scheduler_power: 0.001
epoch_count: 1
pretrain_ckpt_dir: ""
save_ckpt_steps: 116
save_ckpt_dir: "./"
keep_ckpt_max: 10
distribute_batch_size_gpu: 64
dataset_path: ""
data_name: "ag"
run_distribute: False
model_ckpt: ""
# export option
device_id: 0
ckpt_file: ""
file_name: "fasttexts"
file_format: "AIR"
---
# Help description for each configuration
device_target: "Device target"
dataset_path: "FastText input data file path."
data_name: "dataset name. choice in ['ag', 'dbpedia', 'yelp_p']"
run_distribute: "Run distribute, default: false."
model_ckpt: "existed checkpoint address."
# export option
device_id: "Device id"
ckpt_file: "Checkpoint file path"
file_name: "Output file name"
file_format: "Output file format, choice in ['AIR', 'ONNX', 'MINDIR']"

View File

@ -0,0 +1,57 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: False
modelarts_dataset_unzip_name: ""
# ==============================================================================
# options
vocab_size: 6596536
buckets: [64, 128, 256, 512, 3013]
test_buckets: [64, 128, 256, 512, 1120]
batch_size: 4096
embedding_dims: 16
num_class: 14
epoch: 5
lr: 0.8
min_lr: 0.000001 # 1e-6
decay_steps: 549
warmup_steps: 400000
poly_lr_scheduler_power: 0.5
epoch_count: 1
pretrain_ckpt_dir: ""
save_ckpt_steps: 548
save_ckpt_dir: "./"
keep_ckpt_max: 10
distribute_batch_size_gpu: 512
dataset_path: ""
data_name: "dbpedia"
run_distribute: False
model_ckpt: ""
# export option
device_id: 0
ckpt_file: ""
file_name: "fasttexts"
file_format: "AIR"
---
# Help description for each configuration
device_target: "Device target"
dataset_path: "FastText input data file path."
data_name: "dataset name. choice in ['ag', 'dbpedia', 'yelp_p']"
run_distribute: "Run distribute, default: false."
model_ckpt: "existed checkpoint address."
# export option
device_id: "Device id"
ckpt_file: "Checkpoint file path"
file_name: "Output file name"
file_format: "Output file format, choice in ['AIR', 'ONNX', 'MINDIR']"

View File

@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""FastText for Evaluation""" """FastText for Evaluation"""
import argparse import time
import os
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
@ -26,33 +27,21 @@ import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as deC import mindspore.dataset.transforms.c_transforms as deC
from mindspore import context from mindspore import context
from src.fasttext_model import FastText from src.fasttext_model import FastText
parser = argparse.ArgumentParser(description='fasttext')
parser.add_argument('--data_path', type=str, help='infer dataset path..')
parser.add_argument('--data_name', type=str, required=True, default='ag',
help='dataset name. eg. ag, dbpedia')
parser.add_argument("--model_ckpt", type=str, required=True,
help="existed checkpoint address.")
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
args = parser.parse_args() from model_utils.config import config
if args.data_name == "ag": from model_utils.moxing_adapter import moxing_wrapper
from src.config import config_ag as config_ascend from model_utils.device_adapter import get_device_id, get_device_num
from src.config import config_ag_gpu as config_gpu
if config.data_name == "ag":
target_label1 = ['0', '1', '2', '3'] target_label1 = ['0', '1', '2', '3']
elif args.data_name == 'dbpedia': elif config.data_name == 'dbpedia':
from src.config import config_db as config_ascend
from src.config import config_db_gpu as config_gpu
target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13'] target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
elif args.data_name == 'yelp_p': elif config.data_name == 'yelp_p':
from src.config import config_yelpp as config_ascend
from src.config import config_yelpp_gpu as config_gpu
target_label1 = ['0', '1'] target_label1 = ['0', '1']
context.set_context(
mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target=config.device_target)
save_graphs=False,
device_target=args.device_target)
config = config_ascend if args.device_target == 'Ascend' else config_gpu
class FastTextInferCell(nn.Cell): class FastTextInferCell(nn.Cell):
""" """
Encapsulation class of FastText network infer. Encapsulation class of FastText network infer.
@ -77,6 +66,7 @@ class FastTextInferCell(nn.Cell):
return predicted_idx return predicted_idx
def load_infer_dataset(batch_size, datafile, bucket): def load_infer_dataset(batch_size, datafile, bucket):
"""data loader for infer""" """data loader for infer"""
def batch_per_bucket(bucket_length, input_file): def batch_per_bucket(bucket_length, input_file):
@ -103,12 +93,66 @@ def load_infer_dataset(batch_size, datafile, bucket):
return data_set return data_set
def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_fasttext_infer(): def run_fasttext_infer():
"""run infer with FastText""" """run infer with FastText"""
dataset = load_infer_dataset(batch_size=config.batch_size, datafile=args.data_path, bucket=config.test_buckets) dataset = load_infer_dataset(batch_size=config.batch_size, datafile=config.dataset_path, bucket=config.test_buckets)
fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class) fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class)
parameter_dict = load_checkpoint(args.model_ckpt) parameter_dict = load_checkpoint(config.model_ckpt)
load_param_into_net(fasttext_model, parameter_dict=parameter_dict) load_param_into_net(fasttext_model, parameter_dict=parameter_dict)
ft_infer = FastTextInferCell(fasttext_model) ft_infer = FastTextInferCell(fasttext_model)

View File

@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
"""export checkpoint file into models""" """export checkpoint file into models"""
import argparse
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
@ -23,36 +22,16 @@ from mindspore import context
from mindspore.train.serialization import load_checkpoint, export, load_param_into_net from mindspore.train.serialization import load_checkpoint, export, load_param_into_net
from src.fasttext_model import FastText from src.fasttext_model import FastText
parser = argparse.ArgumentParser(description='fasttexts') from model_utils.config import config
parser.add_argument('--device_target', type=str, choices=["Ascend", "GPU", "CPU"],
default='Ascend', help='Device target')
parser.add_argument('--device_id', type=int, default=0, help='Device id')
parser.add_argument('--ckpt_file', type=str, required=True, help='Checkpoint file path')
parser.add_argument('--file_name', type=str, default='fasttexts', help='Output file name')
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR',
help='Output file format')
parser.add_argument('--data_name', type=str, required=True, default='ag',
help='Dataset name. eg. ag, dbpedia, yelp_p')
args = parser.parse_args()
if args.data_name == "ag": if config.data_name == "ag":
from src.config import config_ag as config
from src.config import config_ag_gpu as config_gpu
target_label1 = ['0', '1', '2', '3'] target_label1 = ['0', '1', '2', '3']
elif args.data_name == 'dbpedia': elif config.data_name == 'dbpedia':
from src.config import config_db as config
from src.config import config_db_gpu as config_gpu
target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13'] target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
elif args.data_name == 'yelp_p': elif config.data_name == 'yelp_p':
from src.config import config_yelpp as config
from src.config import config_yelpp_gpu as config_gpu
target_label1 = ['0', '1'] target_label1 = ['0', '1']
context.set_context( context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target=config.device_target)
mode=context.GRAPH_MODE,
save_graphs=False,
device_target=args.device_target)
config = config_ascend if args.device_target == 'Ascend' else config_gpu
class FastTextInferExportCell(nn.Cell): class FastTextInferExportCell(nn.Cell):
""" """
@ -81,26 +60,26 @@ class FastTextInferExportCell(nn.Cell):
def run_fasttext_export(): def run_fasttext_export():
"""export function""" """export function"""
fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class) fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class)
parameter_dict = load_checkpoint(args.ckpt_file) parameter_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(fasttext_model, parameter_dict) load_param_into_net(fasttext_model, parameter_dict)
ft_infer = FastTextInferExportCell(fasttext_model) ft_infer = FastTextInferExportCell(fasttext_model)
batch_size = config.batch_size batch_size = config.batch_size
if args.device_target == 'GPU': if config.device_target == 'GPU':
batch_size = config.distribute_batch_size batch_size = config.distribute_batch_size_gpu
if args.data_name == "ag": if config.data_name == "ag":
src_tokens_shape = [batch_size, 467] src_tokens_shape = [batch_size, 467]
src_tokens_length_shape = [batch_size, 1] src_tokens_length_shape = [batch_size, 1]
elif args.data_name == 'dbpedia': elif config.data_name == 'dbpedia':
src_tokens_shape = [batch_size, 1120] src_tokens_shape = [batch_size, 1120]
src_tokens_length_shape = [batch_size, 1] src_tokens_length_shape = [batch_size, 1]
elif args.data_name == 'yelp_p': elif config.data_name == 'yelp_p':
src_tokens_shape = [batch_size, 2955] src_tokens_shape = [batch_size, 2955]
src_tokens_length_shape = [batch_size, 1] src_tokens_length_shape = [batch_size, 1]
file_name = args.file_name + '_' + args.data_name file_name = config.file_name + '_' + config.data_name
src_tokens = Tensor(np.ones((src_tokens_shape)).astype(np.int32)) src_tokens = Tensor(np.ones((src_tokens_shape)).astype(np.int32))
src_tokens_length = Tensor(np.ones((src_tokens_length_shape)).astype(np.int32)) src_tokens_length = Tensor(np.ones((src_tokens_length_shape)).astype(np.int32))
export(ft_infer, src_tokens, src_tokens_length, file_name=file_name, file_format=args.file_format) export(ft_infer, src_tokens, src_tokens_length, file_name=file_name, file_format=config.file_format)
if __name__ == '__main__': if __name__ == '__main__':
run_fasttext_export() run_fasttext_export()

View File

@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()

View File

@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config
if config.enable_modelarts:
from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
__all__ = [
"get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
return "Local Job"

View File

@ -0,0 +1,116 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from .config import config
_global_sync_count = 0
def get_device_id():
device_id = os.getenv('DEVICE_ID', '0')
return int(device_id)
def get_device_num():
device_num = os.getenv('RANK_SIZE', '1')
return int(device_num)
def get_rank_id():
global_rank_id = os.getenv('RANK_ID', '0')
return int(global_rank_id)
def get_job_id():
job_id = os.getenv('JOB_ID')
job_id = job_id if job_id != "" else "default"
return job_id
def sync_data(from_path, to_path):
"""
Download data from remote obs to local directory if the first url is remote url and the second one is local path
Upload data from local directory to remote obs in contrast.
"""
import moxing as mox
import time
global _global_sync_count
sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
_global_sync_count += 1
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("from path: ", from_path)
print("to path: ", to_path)
mox.file.copy_parallel(from_path, to_path)
print("===finish data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
print("===save flag===")
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Finish sync data from {} to {}.".format(from_path, to_path))
def moxing_wrapper(pre_process=None, post_process=None):
"""
Moxing wrapper to download dataset and upload outputs.
"""
def wrapper(run_func):
@functools.wraps(run_func)
def wrapped_func(*args, **kwargs):
# Download data from data_url
if config.enable_modelarts:
if config.data_url:
sync_data(config.data_url, config.data_path)
print("Dataset downloaded: ", os.listdir(config.data_path))
if config.checkpoint_url:
sync_data(config.checkpoint_url, config.load_path)
print("Preload downloaded: ", os.listdir(config.load_path))
if config.train_url:
sync_data(config.train_url, config.output_path)
print("Workspace downloaded: ", os.listdir(config.output_path))
context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
config.device_num = get_device_num()
config.device_id = get_device_id()
if not os.path.exists(config.output_path):
os.makedirs(config.output_path)
if pre_process:
pre_process()
# Run the main function
run_func(*args, **kwargs)
# Upload data to train_url
if config.enable_modelarts:
if post_process:
post_process()
if config.train_url:
print("Start to copy output directory")
sync_data(config.output_path, config.train_url)
return wrapped_func
return wrapper

View File

@ -48,6 +48,8 @@ echo $RANK_TABLE_FILE
export RANK_SIZE=8 export RANK_SIZE=8
export DEVICE_NUM=8 export DEVICE_NUM=8
config_path="./${DATANAME}_config.yaml"
echo "config path is : ${config_path}"
for((i=0;i<=7;i++)); for((i=0;i<=7;i++));
do do
@ -55,12 +57,14 @@ do
mkdir ${current_exec_path}/device$i mkdir ${current_exec_path}/device$i
cd ${current_exec_path}/device$i || exit cd ${current_exec_path}/device$i || exit
cp ../../*.py ./ cp ../../*.py ./
cp ../../*.yaml ./
cp -r ../../src ./ cp -r ../../src ./
cp -r ../../model_utils ./
cp -r ../*.sh ./ cp -r ../*.sh ./
export RANK_ID=$i export RANK_ID=$i
export DEVICE_ID=$i export DEVICE_ID=$i
echo "start training for rank $i, device $DEVICE_ID" echo "start training for rank $i, device $DEVICE_ID"
python ../../train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 & python ../../train.py --config_path $config_path --dataset_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
cd ${current_exec_path} || exit cd ${current_exec_path} || exit
done done
cd ${current_exec_path} || exit cd ${current_exec_path} || exit

View File

@ -34,6 +34,8 @@ DATANAME=$(basename $DATASET)
echo $DATANAME echo $DATANAME
config_path="./${DATANAME}_config.yaml"
echo "config path is : ${config_path}"
if [ -d "distribute_train" ]; if [ -d "distribute_train" ];
then then
@ -41,11 +43,13 @@ then
fi fi
mkdir ./distribute_train mkdir ./distribute_train
cp ../*.py ./distribute_train cp ../*.py ./distribute_train
cp ../*.yaml ./distribute_train
cp -r ../src ./distribute_train cp -r ../src ./distribute_train
cp -r ../model_utils ./distribute_train
cp -r ../scripts/*.sh ./distribute_train cp -r ../scripts/*.sh ./distribute_train
cd ./distribute_train || exit cd ./distribute_train || exit
echo "start training for $2 GPU devices" echo "start training for $2 GPU devices"
mpirun -n $2 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \ mpirun -n $2 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
python ../../train.py --device_target GPU --run_distribute True --data_path $DATASET --data_name $DATANAME python ../../train.py --config_path $config_path --device_target GPU --run_distribute True --dataset_path $DATASET --data_name $DATANAME
cd .. cd ..

View File

@ -38,6 +38,8 @@ export DEVICE_ID=$DEVICEID
export RANK_ID=0 export RANK_ID=0
export RANK_SIZE=1 export RANK_SIZE=1
config_path="./${DATANAME}_config.yaml"
echo "config path is : ${config_path}"
if [ -d "eval" ]; if [ -d "eval" ];
then then
@ -45,10 +47,12 @@ then
fi fi
mkdir ./eval mkdir ./eval
cp ../*.py ./eval cp ../*.py ./eval
cp ../*.yaml ./eval
cp -r ../src ./eval cp -r ../src ./eval
cp -r ../model_utils ./eval
cp -r ../scripts/*.sh ./eval cp -r ../scripts/*.sh ./eval
cd ./eval || exit cd ./eval || exit
echo "start training for device $DEVICE_ID" echo "start training for device $DEVICE_ID"
env > env.log env > env.log
python ../../eval.py --data_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 & python ../../eval.py --config_path $config_path --dataset_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 &
cd .. cd ..

View File

@ -33,6 +33,8 @@ echo $DATASET
DATANAME=$2 DATANAME=$2
MODEL_CKPT=$(get_real_path $3) MODEL_CKPT=$(get_real_path $3)
config_path="./${DATANAME}_config.yaml"
echo "config path is : ${config_path}"
if [ -d "eval" ]; if [ -d "eval" ];
then then
@ -40,10 +42,12 @@ then
fi fi
mkdir ./eval mkdir ./eval
cp ../*.py ./eval cp ../*.py ./eval
cp ../*.yaml ./eval
cp -r ../src ./eval cp -r ../src ./eval
cp -r ../model_utils ./eval
cp -r ../scripts/*.sh ./eval cp -r ../scripts/*.sh ./eval
cd ./eval || exit cd ./eval || exit
echo "start eval on standalone GPU" echo "start eval on standalone GPU"
python ../../eval.py --device_target GPU --data_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 & python ../../eval.py --config_path $config_path --device_target GPU --dataset_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 &
cd .. cd ..

View File

@ -34,6 +34,9 @@ DATANAME=$(basename $DATASET)
echo $DATANAME echo $DATANAME
DEVICEID=$2 DEVICEID=$2
config_path="./${DATANAME}_config.yaml"
echo "config path is : ${config_path}"
export DEVICE_NUM=1 export DEVICE_NUM=1
export DEVICE_ID=$DEVICEID export DEVICE_ID=$DEVICEID
export RANK_ID=0 export RANK_ID=0
@ -46,10 +49,12 @@ then
fi fi
mkdir ./train mkdir ./train
cp ../*.py ./train cp ../*.py ./train
cp ../*.yaml ./train
cp -r ../src ./train cp -r ../src ./train
cp -r ../model_utils ./train
cp -r ../scripts/*.sh ./train cp -r ../scripts/*.sh ./train
cd ./train || exit cd ./train || exit
echo "start training for device $DEVICE_ID" echo "start training for device $DEVICE_ID"
env > env.log env > env.log
python train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 & python train.py --config_path $config_path --dataset_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
cd .. cd ..

View File

@ -32,16 +32,21 @@ echo $DATASET
DATANAME=$(basename $DATASET) DATANAME=$(basename $DATASET)
echo $DATANAME echo $DATANAME
config_path="./${DATANAME}_config.yaml"
echo "config path is : ${config_path}"
if [ -d "train" ]; if [ -d "train" ];
then then
rm -rf ./train rm -rf ./train
fi fi
mkdir ./train mkdir ./train
cp ../*.py ./train cp ../*.py ./train
cp ../*.yaml ./train
cp -r ../src ./train cp -r ../src ./train
cp -r ../model_utils ./train
cp -r ../scripts/*.sh ./train cp -r ../scripts/*.sh ./train
cd ./train || exit cd ./train || exit
echo "start training for standalone GPU device" echo "start training for standalone GPU device"
python train.py --device_target="GPU" --data_path=$1 --data_name=$DATANAME > log_fasttext.log 2>&1 & python train.py --config_path=$config_path --device_target="GPU" --dataset_path=$1 --data_name=$DATANAME > log_fasttext.log 2>&1 &
cd .. cd ..

View File

@ -1,135 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" :===========================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config_yelpp = ed({
'vocab_size': 6414979,
'buckets': [64, 128, 256, 512, 2955],
'test_buckets': [64, 128, 256, 512, 2955],
'batch_size': 2048,
'embedding_dims': 16,
'num_class': 2,
'epoch': 5,
'lr': 0.30,
'min_lr': 1e-6,
'decay_steps': 549,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.5,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 549,
'keep_ckpt_max': 10,
})
config_db = ed({
'vocab_size': 6596536,
'buckets': [64, 128, 256, 512, 3013],
'test_buckets': [64, 128, 256, 512, 1120],
'batch_size': 4096,
'embedding_dims': 16,
'num_class': 14,
'epoch': 5,
'lr': 0.8,
'min_lr': 1e-6,
'decay_steps': 549,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.5,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 548,
'keep_ckpt_max': 10,
})
config_ag = ed({
'vocab_size': 1383812,
'buckets': [64, 128, 467],
'test_buckets': [467],
'batch_size': 512,
'embedding_dims': 16,
'num_class': 4,
'epoch': 5,
'lr': 0.2,
'min_lr': 1e-6,
'decay_steps': 115,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.001,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 116,
'keep_ckpt_max': 10,
})
config_yelpp_gpu = ed({
'vocab_size': 6414979,
'buckets': [64, 128, 256, 512, 2955],
'test_buckets': [64, 128, 256, 512, 2955],
'batch_size': 2048,
'distribute_batch_size': 512,
'embedding_dims': 16,
'num_class': 2,
'epoch': 5,
'lr': 0.30,
'min_lr': 1e-6,
'decay_steps': 549,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.5,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 549,
'keep_ckpt_max': 10,
})
config_db_gpu = ed({
'vocab_size': 6596536,
'buckets': [64, 128, 256, 512, 3013],
'test_buckets': [64, 128, 256, 512, 1120],
'batch_size': 4096,
'distribute_batch_size': 512,
'embedding_dims': 16,
'num_class': 14,
'epoch': 5,
'lr': 0.8,
'min_lr': 1e-6,
'decay_steps': 549,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.5,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 548,
'keep_ckpt_max': 10,
})
config_ag_gpu = ed({
'vocab_size': 1383812,
'buckets': [64, 128, 467],
'test_buckets': [467],
'batch_size': 512,
'distribute_batch_size': 64,
'embedding_dims': 16,
'num_class': 4,
'epoch': 5,
'lr': 0.2,
'min_lr': 1e-6,
'decay_steps': 115,
'warmup_steps': 400000,
'poly_lr_scheduler_power': 0.001,
'epoch_count': 1,
'pretrain_ckpt_dir': None,
'save_ckpt_steps': 116,
'keep_ckpt_max': 10,
})

View File

@ -15,9 +15,8 @@
"""FastText for train""" """FastText for train"""
import os import os
import time import time
import argparse
import ast
from mindspore import context from mindspore import context
from mindspore.communication.management import init, get_rank
from mindspore.nn.optim import Adam from mindspore.nn.optim import Adam
from mindspore.common import set_seed from mindspore.common import set_seed
from mindspore.train.model import Model from mindspore.train.model import Model
@ -28,41 +27,23 @@ from mindspore.train.callback import Callback, TimeMonitor
from mindspore.communication import management as MultiDevice from mindspore.communication import management as MultiDevice
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.load_dataset import load_dataset from src.load_dataset import load_dataset
from src.lr_schedule import polynomial_decay_scheduler from src.lr_schedule import polynomial_decay_scheduler
from src.fasttext_train import FastTextTrainOneStepCell, FastTextNetWithLoss from src.fasttext_train import FastTextTrainOneStepCell, FastTextNetWithLoss
parser = argparse.ArgumentParser() from model_utils.config import config
parser.add_argument('--data_path', type=str, required=True, help='FastText input data file path.') from model_utils.moxing_adapter import moxing_wrapper
parser.add_argument('--data_name', type=str, required=True, default='ag', help='dataset name. eg. ag, dbpedia') from model_utils.device_adapter import get_device_id, get_device_num
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute, default: false.')
args = parser.parse_args()
if args.data_name == "ag":
from src.config import config_ag as config_ascend
from src.config import config_ag_gpu as config_gpu
elif args.data_name == 'dbpedia':
from src.config import config_db as config_ascend
from src.config import config_db_gpu as config_gpu
elif args.data_name == 'yelp_p':
from src.config import config_yelpp as config_ascend
from src.config import config_yelpp_gpu as config_gpu
def get_ms_timestamp(): def get_ms_timestamp():
t = time.time() t = time.time()
return int(round(t * 1000)) return int(round(t * 1000))
set_seed(5) set_seed(5)
time_stamp_init = False time_stamp_init = False
time_stamp_first = 0 time_stamp_first = 0
rank_id = os.getenv('DEVICE_ID') context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target=config.device_target)
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target=args.device_target)
config = config_ascend if args.device_target == 'Ascend' else config_gpu
class LossCallBack(Callback): class LossCallBack(Callback):
""" """
@ -142,7 +123,7 @@ def _build_training_pipeline(pre_dataset, run_distribute=False):
net_with_grads = FastTextTrainOneStepCell(net_with_loss, optimizer=optimizer) net_with_grads = FastTextTrainOneStepCell(net_with_loss, optimizer=optimizer)
net_with_grads.set_train(True) net_with_grads.set_train(True)
model = Model(net_with_grads) model = Model(net_with_grads)
loss_monitor = LossCallBack(rank_ids=rank_id) loss_monitor = LossCallBack(rank_ids=config.rank_id)
dataset_size = pre_dataset.get_dataset_size() dataset_size = pre_dataset.get_dataset_size()
time_monitor = TimeMonitor(data_size=dataset_size) time_monitor = TimeMonitor(data_size=dataset_size)
ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps * config.epoch, ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps * config.epoch,
@ -150,12 +131,14 @@ def _build_training_pipeline(pre_dataset, run_distribute=False):
callbacks = [time_monitor, loss_monitor] callbacks = [time_monitor, loss_monitor]
if not run_distribute: if not run_distribute:
ckpt_callback = ModelCheckpoint(prefix='fasttext', ckpt_callback = ModelCheckpoint(prefix='fasttext',
directory=os.path.join('./', 'ckpt_{}'.format(os.getenv("DEVICE_ID"))), directory=os.path.join(config.save_ckpt_dir,
'ckpt_{}'.format(os.getenv("DEVICE_ID"))),
config=ckpt_config) config=ckpt_config)
callbacks.append(ckpt_callback) callbacks.append(ckpt_callback)
if run_distribute and MultiDevice.get_rank() % 8 == 0: if run_distribute and MultiDevice.get_rank() % 8 == 0:
ckpt_callback = ModelCheckpoint(prefix='fasttext', ckpt_callback = ModelCheckpoint(prefix='fasttext',
directory=os.path.join('./', 'ckpt_{}'.format(os.getenv("DEVICE_ID"))), directory=os.path.join(config.save_ckpt_dir,
'ckpt_{}'.format(os.getenv("DEVICE_ID"))),
config=ckpt_config) config=ckpt_config)
callbacks.append(ckpt_callback) callbacks.append(ckpt_callback)
print("Prepare to Training....") print("Prepare to Training....")
@ -186,6 +169,7 @@ def set_parallel_env():
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=MultiDevice.get_group_size(), device_num=MultiDevice.get_group_size(),
gradients_mean=True) gradients_mean=True)
def train_paralle(input_file_path): def train_paralle(input_file_path):
""" """
Train model on multi device Train model on multi device
@ -195,8 +179,8 @@ def train_paralle(input_file_path):
set_parallel_env() set_parallel_env()
print("Starting traning on multiple devices. |~ _ ~| |~ _ ~| |~ _ ~| |~ _ ~|") print("Starting traning on multiple devices. |~ _ ~| |~ _ ~| |~ _ ~| |~ _ ~|")
batch_size = config.batch_size batch_size = config.batch_size
if args.device_target == 'GPU': if config.device_target == 'GPU':
batch_size = config.distribute_batch_size batch_size = config.distribute_batch_size_gpu
preprocessed_data = load_dataset(dataset_path=input_file_path, preprocessed_data = load_dataset(dataset_path=input_file_path,
batch_size=batch_size, batch_size=batch_size,
@ -207,8 +191,76 @@ def train_paralle(input_file_path):
shuffle=False) shuffle=False)
_build_training_pipeline(preprocessed_data, True) _build_training_pipeline(preprocessed_data, True)
if __name__ == "__main__":
if args.run_distribute: def modelarts_pre_process():
train_paralle(args.data_path) '''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
config.save_ckpt_dir = os.path.join(config.output_path, config.save_ckpt_dir)
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_train():
'''run train.'''
if config.device_target == "Ascend":
config.rank_id = get_device_id()
elif config.device_target == "GPU":
init("nccl")
config.rank_id = get_rank()
else: else:
train_single(args.data_path) raise ValueError("Not support device target: {}".format(config.device_target))
if config.run_distribute:
train_paralle(config.dataset_path)
else:
train_single(config.dataset_path)
if __name__ == "__main__":
run_train()

View File

@ -0,0 +1,57 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: False
modelarts_dataset_unzip_name: ""
# ==============================================================================
# options
vocab_size: 6414979
buckets: [64, 128, 256, 512, 2955]
test_buckets: [64, 128, 256, 512, 2955]
batch_size: 2048
embedding_dims: 16
num_class: 2
epoch: 5
lr: 0.30
min_lr: 0.000001 # 1e-6
decay_steps: 549
warmup_steps: 400000
poly_lr_scheduler_power: 0.5
epoch_count: 1
pretrain_ckpt_dir: ""
save_ckpt_steps: 549
save_ckpt_dir: "./"
keep_ckpt_max: 10
distribute_batch_size_gpu: 512
dataset_path: ""
data_name: "yelp_p"
run_distribute: False
model_ckpt: ""
# export option
device_id: 0
ckpt_file: ""
file_name: "fasttexts"
file_format: "AIR"
---
# Help description for each configuration
device_target: "Device target"
dataset_path: "FastText input data file path."
data_name: "dataset name. choice in ['ag', 'dbpedia', 'yelp_p']"
run_distribute: "Run distribute, default: false."
model_ckpt: "existed checkpoint address."
# export option
device_id: "Device id"
ckpt_file: "Checkpoint file path"
file_name: "Output file name"
file_format: "Output file format, choice in ['AIR', 'ONNX', 'MINDIR']"