forked from mindspore-Ecosystem/mindspore

!18043 modify fasttext network for cloud

Merge pull request !18043 from zhanghuiyao/fasttext_clould

This commit is contained in commit 0fce707dd9.
@@ -63,17 +63,66 @@ architecture. In the following sections, we will introduce how to run the script

After dataset preparation, you can start training and evaluation as follows:

-```bash
-# run training example
-cd ./scripts
-sh run_standalone_train.sh [TRAIN_DATASET] [DEVICEID]
-
-# run distributed training example
-sh run_distribute_train.sh [TRAIN_DATASET] [RANK_TABLE_PATH]
-
-# run evaluation example
-sh run_eval.sh [EVAL_DATASET_PATH] [DATASET_NAME] [MODEL_CKPT] [DEVICEID]
-```
+- Running on Ascend
+
+    ```bash
+    # run training example
+    cd ./scripts
+    sh run_standalone_train.sh [TRAIN_DATASET] [DEVICEID]
+
+    # run distributed training example
+    sh run_distribute_train.sh [TRAIN_DATASET] [RANK_TABLE_PATH]
+
+    # run evaluation example
+    sh run_eval.sh [EVAL_DATASET_PATH] [DATASET_NAME] [MODEL_CKPT] [DEVICEID]
+    ```
+
+- ModelArts (If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/), and you can start training as follows)
+
+    ```python
+    # run standalone training example
+    # (1) Add "config_path='/path_to_code/[DATASET_NAME]_config.yaml'" on the website UI interface.
+    # (2) Perform a or b.
+    #       a. Set "enable_modelarts=True" in the [DATASET_NAME]_config.yaml file.
+    #          Set "dataset_path='/cache/data/[DATASET_NAME]'" in the [DATASET_NAME]_config.yaml file.
+    #          Set "data_name='[DATASET_NAME]'" in the [DATASET_NAME]_config.yaml file.
+    #          (optional) Set "device_target='GPU'" in the [DATASET_NAME]_config.yaml file if you run with GPU.
+    #          (optional) Set any other parameters you need in the [DATASET_NAME]_config.yaml file.
+    #       b. Add "enable_modelarts=True" on the website UI interface.
+    #          Add "dataset_path='/cache/data/[DATASET_NAME]'" on the website UI interface.
+    #          Add "data_name='[DATASET_NAME]'" on the website UI interface.
+    #          (optional) Set "device_target='GPU'" on the website UI interface if you run with GPU.
+    #          (optional) Set other parameters on the website UI interface.
+    # (3) Upload a zip dataset to the S3 bucket. (You can also upload the original dataset, but it may be slow.)
+    # (4) Set the code directory to "/path/fasttext" on the website UI interface.
+    # (5) Set the startup file to "train.py" on the website UI interface.
+    # (6) Set the "Dataset path", "Output file path" and "Job log path" to your own paths on the website UI interface.
+    # (7) Create your job.
+    #
+    # run evaluation example
+    # (1) Add "config_path='/path_to_code/[DATASET_NAME]_config.yaml'" on the website UI interface.
+    # (2) Perform a or b.
+    #       a. Set "enable_modelarts=True" in the [DATASET_NAME]_config.yaml file.
+    #          Set "dataset_path='/cache/data/[DATASET_NAME]'" in the [DATASET_NAME]_config.yaml file.
+    #          Set "data_name='[DATASET_NAME]'" in the [DATASET_NAME]_config.yaml file.
+    #          Set "checkpoint_url='s3://dir_to_trained_ckpt/'" in the [DATASET_NAME]_config.yaml file.
+    #          Set "model_ckpt='/cache/checkpoint_path/model.ckpt'" in the [DATASET_NAME]_config.yaml file.
+    #          (optional) Set "device_target='GPU'" in the [DATASET_NAME]_config.yaml file if you run with GPU.
+    #          (optional) Set any other parameters you need in the [DATASET_NAME]_config.yaml file.
+    #       b. Add "enable_modelarts=True" on the website UI interface.
+    #          Add "dataset_path='/cache/data/[DATASET_NAME]'" on the website UI interface.
+    #          Add "data_name='[DATASET_NAME]'" on the website UI interface.
+    #          Add "checkpoint_url='s3://dir_to_trained_ckpt/'" on the website UI interface.
+    #          Add "model_ckpt='/cache/checkpoint_path/model.ckpt'" on the website UI interface.
+    #          (optional) Set "device_target='GPU'" on the website UI interface if you run with GPU.
+    #          (optional) Set other parameters on the website UI interface.
+    # (3) Upload or copy your pretrained model to the S3 bucket.
+    # (4) Upload a zip dataset to the S3 bucket. (You can also upload the original dataset, but it may be slow.)
+    # (5) Set the code directory to "/path/fasttext" on the website UI interface.
+    # (6) Set the startup file to "eval.py" on the website UI interface.
+    # (7) Set the "Dataset path", "Output file path" and "Job log path" to your own paths on the website UI interface.
+    # (8) Create your job.
+    ```
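For orientation, a concrete Ascend invocation with hypothetical paths might look like this; run_standalone_train.sh appears to derive the yaml config from the basename of the dataset path, so the dataset file or directory should be named ag, dbpedia, or yelp_p:

```bash
# hypothetical paths; the basename "ag" makes the script pick ./ag_config.yaml
cd ./scripts
sh run_standalone_train.sh /data/fasttext/ag 0
sh run_eval.sh /data/fasttext/ag_test ag /path/to/fasttext-5_116.ckpt 0
```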

## [Script Description](#content)
@@ -82,8 +131,13 @@ The FastText network script and code result are as follows:

```text
├── fasttext
├── README.md // Introduction of FastText model.
+├── model_utils
+│   ├──__init__.py // module init file
+│   ├──config.py // Parse arguments
+│   ├──device_adapter.py // Device adapter for ModelArts
+│   ├──local_adapter.py // Local adapter
+│   ├──moxing_adapter.py // Moxing adapter for ModelArts
├── src
-│   ├──config.py // Configuration instance definition.
│   ├──create_dataset.py // Dataset preparation.
│   ├──fasttext_model.py // FastText model architecture.
│   ├──fasttext_train.py // Use FastText model architecture.

@@ -96,6 +150,11 @@ The FastText network script and code result are as follows:

│   ├──run_distributed_train_gpu.sh // shell script for distributed train on GPU.
│   ├──run_eval_gpu.sh // shell script for standalone eval on GPU.
│   ├──run_standalone_train_gpu.sh // shell script for standalone train on GPU.
+├── ag_config.yaml // ag dataset arguments
+├── dbpedia_config.yaml // dbpedia dataset arguments
+├── yelpp_config.yaml // yelpp dataset arguments
├── mindspore_hub_conf.py // mindspore hub scripts
├── export.py // Export API entry.
├── eval.py // Infer API entry.
├── requirements.txt // Requirements of third party package.
├── train.py // Train API entry.
@@ -0,0 +1,57 @@

# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: False
modelarts_dataset_unzip_name: ""

# ==============================================================================
# options
vocab_size: 1383812
buckets: [64, 128, 467]
test_buckets: [467]
batch_size: 512
embedding_dims: 16
num_class: 4
epoch: 5
lr: 0.2
min_lr: 0.000001 # 1e-6
decay_steps: 115
warmup_steps: 400000
poly_lr_scheduler_power: 0.001
epoch_count: 1
pretrain_ckpt_dir: ""
save_ckpt_steps: 116
save_ckpt_dir: "./"
keep_ckpt_max: 10
distribute_batch_size_gpu: 64

dataset_path: ""
data_name: "ag"
run_distribute: False
model_ckpt: ""
# export option
device_id: 0
ckpt_file: ""
file_name: "fasttexts"
file_format: "AIR"
---

# Help description for each configuration
device_target: "Device target"
dataset_path: "FastText input data file path."
data_name: "dataset name. choice in ['ag', 'dbpedia', 'yelp_p']"
run_distribute: "Run distribute, default: false."
model_ckpt: "existed checkpoint address."
# export option
device_id: "Device id"
ckpt_file: "Checkpoint file path"
file_name: "Output file name"
file_format: "Output file format, choice in ['AIR', 'ONNX', 'MINDIR']"
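The buckets/test_buckets entries above drive length-bucketed batching: each tokenized sample is padded up to the smallest bucket that fits it, so the graph only ever sees a few fixed sequence lengths. A minimal sketch of the idea (illustrative only, not the project's actual loader; the pad id and truncation policy are assumptions):

```python
# Illustrative: pick the smallest bucket that fits a tokenized sample,
# then pad to that length so batch shapes stay fixed per bucket.
buckets = [64, 128, 467]  # values from ag_config.yaml

def assign_bucket(tokens, buckets):
    for size in sorted(buckets):
        if len(tokens) <= size:
            return tokens + [0] * (size - len(tokens))  # 0 = pad id (assumption)
    return tokens[:max(buckets)]  # truncate overlong samples (assumption)

padded = assign_bucket(list(range(100)), buckets)
assert len(padded) == 128  # 100 tokens land in the 128 bucket
```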
@@ -0,0 +1,57 @@

# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: False
modelarts_dataset_unzip_name: ""

# ==============================================================================
# options
vocab_size: 6596536
buckets: [64, 128, 256, 512, 3013]
test_buckets: [64, 128, 256, 512, 1120]
batch_size: 4096
embedding_dims: 16
num_class: 14
epoch: 5
lr: 0.8
min_lr: 0.000001 # 1e-6
decay_steps: 549
warmup_steps: 400000
poly_lr_scheduler_power: 0.5
epoch_count: 1
pretrain_ckpt_dir: ""
save_ckpt_steps: 548
save_ckpt_dir: "./"
keep_ckpt_max: 10
distribute_batch_size_gpu: 512

dataset_path: ""
data_name: "dbpedia"
run_distribute: False
model_ckpt: ""
# export option
device_id: 0
ckpt_file: ""
file_name: "fasttexts"
file_format: "AIR"
---

# Help description for each configuration
device_target: "Device target"
dataset_path: "FastText input data file path."
data_name: "dataset name. choice in ['ag', 'dbpedia', 'yelp_p']"
run_distribute: "Run distribute, default: false."
model_ckpt: "existed checkpoint address."
# export option
device_id: "Device id"
ckpt_file: "Checkpoint file path"
file_name: "Output file name"
file_format: "Output file format, choice in ['AIR', 'ONNX', 'MINDIR']"
@@ -13,7 +13,8 @@

# limitations under the License.
# ============================================================================
"""FastText for Evaluation"""
-import argparse
+import time
+import os
import numpy as np

import mindspore.nn as nn

@@ -26,33 +27,21 @@ import mindspore.dataset as ds

import mindspore.dataset.transforms.c_transforms as deC
from mindspore import context
from src.fasttext_model import FastText
-parser = argparse.ArgumentParser(description='fasttext')
-parser.add_argument('--data_path', type=str, help='infer dataset path..')
-parser.add_argument('--data_name', type=str, required=True, default='ag',
-                    help='dataset name. eg. ag, dbpedia')
-parser.add_argument("--model_ckpt", type=str, required=True,
-                    help="existed checkpoint address.")
-parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
-                    help='device where the code will be implemented (default: Ascend)')
-
-args = parser.parse_args()
-if args.data_name == "ag":
-    from src.config import config_ag as config_ascend
-    from src.config import config_ag_gpu as config_gpu
+from model_utils.config import config
+from model_utils.moxing_adapter import moxing_wrapper
+from model_utils.device_adapter import get_device_id, get_device_num
+
+if config.data_name == "ag":
    target_label1 = ['0', '1', '2', '3']
-elif args.data_name == 'dbpedia':
-    from src.config import config_db as config_ascend
-    from src.config import config_db_gpu as config_gpu
+elif config.data_name == 'dbpedia':
    target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
-elif args.data_name == 'yelp_p':
-    from src.config import config_yelpp as config_ascend
-    from src.config import config_yelpp_gpu as config_gpu
+elif config.data_name == 'yelp_p':
    target_label1 = ['0', '1']
-context.set_context(
-    mode=context.GRAPH_MODE,
-    save_graphs=False,
-    device_target=args.device_target)
-config = config_ascend if args.device_target == 'Ascend' else config_gpu
+
+context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target=config.device_target)


class FastTextInferCell(nn.Cell):
    """
    Encapsulation class of FastText network infer.

@@ -77,6 +66,7 @@ class FastTextInferCell(nn.Cell):

        return predicted_idx


def load_infer_dataset(batch_size, datafile, bucket):
    """data loader for infer"""
    def batch_per_bucket(bucket_length, input_file):

@@ -103,12 +93,66 @@ def load_infer_dataset(batch_size, datafile, bucket):

    return data_set


+def modelarts_pre_process():
+    '''modelarts pre process function.'''
+    def unzip(zip_file, save_dir):
+        import zipfile
+        s_time = time.time()
+        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
+            zip_isexist = zipfile.is_zipfile(zip_file)
+            if zip_isexist:
+                fz = zipfile.ZipFile(zip_file, 'r')
+                data_num = len(fz.namelist())
+                print("Extract Start...")
+                print("unzip file num: {}".format(data_num))
+                data_print = int(data_num / 100) if data_num > 100 else 1
+                i = 0
+                for file in fz.namelist():
+                    if i % data_print == 0:
+                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
+                    i += 1
+                    fz.extract(file, save_dir)
+                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
+                                                     int(int(time.time() - s_time) % 60)))
+                print("Extract Done.")
+            else:
+                print("This is not zip.")
+        else:
+            print("Zip has been extracted.")
+
+    if config.need_modelarts_dataset_unzip:
+        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
+        save_dir_1 = os.path.join(config.data_path)
+
+        sync_lock = "/tmp/unzip_sync.lock"
+
+        # Each server contains 8 devices as most.
+        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+            print("Zip file path: ", zip_file_1)
+            print("Unzip file save dir: ", save_dir_1)
+            unzip(zip_file_1, save_dir_1)
+            print("===Finish extract data synchronization===")
+            try:
+                os.mknod(sync_lock)
+            except IOError:
+                pass
+
+        while True:
+            if os.path.exists(sync_lock):
+                break
+            time.sleep(1)
+
+        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
+
+
+@moxing_wrapper(pre_process=modelarts_pre_process)
def run_fasttext_infer():
    """run infer with FastText"""
-    dataset = load_infer_dataset(batch_size=config.batch_size, datafile=args.data_path, bucket=config.test_buckets)
+    dataset = load_infer_dataset(batch_size=config.batch_size, datafile=config.dataset_path, bucket=config.test_buckets)
    fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class)

-    parameter_dict = load_checkpoint(args.model_ckpt)
+    parameter_dict = load_checkpoint(config.model_ckpt)
    load_param_into_net(fasttext_model, parameter_dict=parameter_dict)

    ft_infer = FastTextInferCell(fasttext_model)
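The unzip logic added above uses a simple file-based barrier so that only one process per server extracts the archive while the other devices wait. Reduced to its core (a sketch; do_extract stands in for the unzip call, and get_device_id/get_device_num are the model_utils.device_adapter helpers imported above):

```python
import os
import time

sync_lock = "/tmp/unzip_sync.lock"  # shared path seen by every device on one server

def barrier_extract(do_extract):
    # device 0 of each 8-device server does the work exactly once
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        do_extract()
        try:
            os.mknod(sync_lock)  # create the flag file to signal completion
        except IOError:
            pass
    while not os.path.exists(sync_lock):  # everyone else polls for the flag
        time.sleep(1)
```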
@@ -14,7 +14,6 @@

# ============================================================================
"""export checkpoint file into models"""

-import argparse
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor

@@ -23,36 +22,16 @@ from mindspore import context

from mindspore.train.serialization import load_checkpoint, export, load_param_into_net
from src.fasttext_model import FastText

-parser = argparse.ArgumentParser(description='fasttexts')
-parser.add_argument('--device_target', type=str, choices=["Ascend", "GPU", "CPU"],
-                    default='Ascend', help='Device target')
-parser.add_argument('--device_id', type=int, default=0, help='Device id')
-parser.add_argument('--ckpt_file', type=str, required=True, help='Checkpoint file path')
-parser.add_argument('--file_name', type=str, default='fasttexts', help='Output file name')
-parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR',
-                    help='Output file format')
-parser.add_argument('--data_name', type=str, required=True, default='ag',
-                    help='Dataset name. eg. ag, dbpedia, yelp_p')
-args = parser.parse_args()
+from model_utils.config import config

-if args.data_name == "ag":
-    from src.config import config_ag as config
-    from src.config import config_ag_gpu as config_gpu
+if config.data_name == "ag":
    target_label1 = ['0', '1', '2', '3']
-elif args.data_name == 'dbpedia':
-    from src.config import config_db as config
-    from src.config import config_db_gpu as config_gpu
+elif config.data_name == 'dbpedia':
    target_label1 = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13']
-elif args.data_name == 'yelp_p':
-    from src.config import config_yelpp as config
-    from src.config import config_yelpp_gpu as config_gpu
+elif config.data_name == 'yelp_p':
    target_label1 = ['0', '1']

-context.set_context(
-    mode=context.GRAPH_MODE,
-    save_graphs=False,
-    device_target=args.device_target)
-config = config_ascend if args.device_target == 'Ascend' else config_gpu
+context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target=config.device_target)

class FastTextInferExportCell(nn.Cell):
    """

@@ -81,26 +60,26 @@ class FastTextInferExportCell(nn.Cell):

def run_fasttext_export():
    """export function"""
    fasttext_model = FastText(config.vocab_size, config.embedding_dims, config.num_class)
-    parameter_dict = load_checkpoint(args.ckpt_file)
+    parameter_dict = load_checkpoint(config.ckpt_file)
    load_param_into_net(fasttext_model, parameter_dict)
    ft_infer = FastTextInferExportCell(fasttext_model)
    batch_size = config.batch_size
-    if args.device_target == 'GPU':
-        batch_size = config.distribute_batch_size
-    if args.data_name == "ag":
+    if config.device_target == 'GPU':
+        batch_size = config.distribute_batch_size_gpu
+    if config.data_name == "ag":
        src_tokens_shape = [batch_size, 467]
        src_tokens_length_shape = [batch_size, 1]
-    elif args.data_name == 'dbpedia':
+    elif config.data_name == 'dbpedia':
        src_tokens_shape = [batch_size, 1120]
        src_tokens_length_shape = [batch_size, 1]
-    elif args.data_name == 'yelp_p':
+    elif config.data_name == 'yelp_p':
        src_tokens_shape = [batch_size, 2955]
        src_tokens_length_shape = [batch_size, 1]

-    file_name = args.file_name + '_' + args.data_name
+    file_name = config.file_name + '_' + config.data_name
    src_tokens = Tensor(np.ones((src_tokens_shape)).astype(np.int32))
    src_tokens_length = Tensor(np.ones((src_tokens_length_shape)).astype(np.int32))
-    export(ft_infer, src_tokens, src_tokens_length, file_name=file_name, file_format=args.file_format)
+    export(ft_infer, src_tokens, src_tokens_length, file_name=file_name, file_format=config.file_format)


if __name__ == '__main__':
    run_fasttext_export()
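With argparse removed, all export options now arrive through the same yaml-plus-CLI path; an export run might look like this (paths and the checkpoint name are placeholders):

```bash
python export.py --config_path=./ag_config.yaml --ckpt_file=./fasttext-5_116.ckpt \
    --data_name=ag --file_name=fasttexts --file_format=MINDIR
```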
@@ -0,0 +1,127 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parse arguments"""

import os
import ast
import argparse
from pprint import pprint, pformat
import yaml


class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = [x for x in cfgs]
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
                cfg_choices = {}
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
                cfg_choices = {}
            elif len(cfgs) == 3:
                cfg, cfg_helper, cfg_choices = cfgs
            else:
                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
            print(cfg_helper)
        except:
            raise ValueError("Failed to parse yaml")
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    pprint(default)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)

config = get_config()
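model_utils/config.py builds the final configuration at import time: it reads the yaml document(s) behind --config_path, registers one CLI flag per scalar key (lists and dicts stay yaml-only), and lets command-line values override the yaml defaults. A hedged usage sketch (the flag values are examples, not defaults):

```python
# e.g. invoked as: python train.py --config_path=./ag_config.yaml --lr=0.1
from model_utils.config import config  # yaml and CLI are already merged here

print(config.lr)       # 0.1  (the CLI override wins over the yaml default 0.2)
print(config.buckets)  # [64, 128, 467]  (list value, taken from the yaml as-is)
print(config.data_name, config.device_target)
```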
@@ -0,0 +1,27 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Device adapter for ModelArts"""

from .config import config

if config.enable_modelarts:
    from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
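Callers never branch on the platform themselves; they import from device_adapter and transparently get the ModelArts or the local env-var implementation depending on enable_modelarts. For instance (the printed values assume the environment variables are unset):

```python
from model_utils.device_adapter import get_device_id, get_device_num, get_rank_id

# with enable_modelarts False these read DEVICE_ID / RANK_SIZE / RANK_ID
print(get_device_id(), get_device_num(), get_rank_id())  # 0 1 0 by default
```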
@@ -0,0 +1,36 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Local adapter"""

import os

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    return "Local Job"
@@ -0,0 +1,116 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Moxing adapter for ModelArts"""

import os
import functools
from mindspore import context
from .config import config

_global_sync_count = 0

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    job_id = os.getenv('JOB_ID')
    job_id = job_id if job_id != "" else "default"
    return job_id

def sync_data(from_path, to_path):
    """
    Download data from remote obs to local directory if the first url is remote url and the second one is local path
    Upload data from local directory to remote obs in contrast.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices as most.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
        except IOError:
            pass
        print("===save flag===")

    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            # Run the main function
            run_func(*args, **kwargs)

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
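In practice the wrapper is applied exactly as in train.py and eval.py below: pre_process runs after the OBS downloads and before the decorated entry point, and outputs are synced back to train_url afterwards. A minimal sketch (the hook body and main are hypothetical):

```python
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper

def prepare_data():
    # hypothetical hook: on ModelArts, inputs already sit under config.data_path
    print("data ready under", config.data_path)

@moxing_wrapper(pre_process=prepare_data)
def main():
    print("training would start here; outputs go to", config.output_path)

if __name__ == "__main__":
    main()
```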
@@ -48,6 +48,8 @@ echo $RANK_TABLE_FILE

export RANK_SIZE=8
export DEVICE_NUM=8

+config_path="./${DATANAME}_config.yaml"
+echo "config path is : ${config_path}"

for((i=0;i<=7;i++));
do

@@ -55,12 +57,14 @@ do

    mkdir ${current_exec_path}/device$i
    cd ${current_exec_path}/device$i || exit
    cp ../../*.py ./
+    cp ../../*.yaml ./
    cp -r ../../src ./
+    cp -r ../../model_utils ./
    cp -r ../*.sh ./
    export RANK_ID=$i
    export DEVICE_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
-    python ../../train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
+    python ../../train.py --config_path $config_path --dataset_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
    cd ${current_exec_path} || exit
done
cd ${current_exec_path} || exit

@@ -34,6 +34,8 @@ DATANAME=$(basename $DATASET)

echo $DATANAME

+config_path="./${DATANAME}_config.yaml"
+echo "config path is : ${config_path}"

if [ -d "distribute_train" ];
then

@@ -41,11 +43,13 @@ then

fi
mkdir ./distribute_train
cp ../*.py ./distribute_train
+cp ../*.yaml ./distribute_train
cp -r ../src ./distribute_train
+cp -r ../model_utils ./distribute_train
cp -r ../scripts/*.sh ./distribute_train
cd ./distribute_train || exit
echo "start training for $2 GPU devices"

mpirun -n $2 --allow-run-as-root --output-filename log_output --merge-stderr-to-stdout \
-    python ../../train.py --device_target GPU --run_distribute True --data_path $DATASET --data_name $DATANAME
+    python ../../train.py --config_path $config_path --device_target GPU --run_distribute True --dataset_path $DATASET --data_name $DATANAME
cd ..

@@ -38,6 +38,8 @@ export DEVICE_ID=$DEVICEID

export RANK_ID=0
export RANK_SIZE=1

+config_path="./${DATANAME}_config.yaml"
+echo "config path is : ${config_path}"

if [ -d "eval" ];
then

@@ -45,10 +47,12 @@ then

fi
mkdir ./eval
cp ../*.py ./eval
+cp ../*.yaml ./eval
cp -r ../src ./eval
+cp -r ../model_utils ./eval
cp -r ../scripts/*.sh ./eval
cd ./eval || exit
echo "start training for device $DEVICE_ID"
env > env.log
-python ../../eval.py --data_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 &
+python ../../eval.py --config_path $config_path --dataset_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 &
cd ..

@@ -33,6 +33,8 @@ echo $DATASET

DATANAME=$2
MODEL_CKPT=$(get_real_path $3)

+config_path="./${DATANAME}_config.yaml"
+echo "config path is : ${config_path}"

if [ -d "eval" ];
then

@@ -40,10 +42,12 @@ then

fi
mkdir ./eval
cp ../*.py ./eval
+cp ../*.yaml ./eval
cp -r ../src ./eval
+cp -r ../model_utils ./eval
cp -r ../scripts/*.sh ./eval
cd ./eval || exit
echo "start eval on standalone GPU"

-python ../../eval.py --device_target GPU --data_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 &
+python ../../eval.py --config_path $config_path --device_target GPU --dataset_path $DATASET --data_name $DATANAME --model_ckpt $MODEL_CKPT> log_fasttext.log 2>&1 &
cd ..

@@ -34,6 +34,9 @@ DATANAME=$(basename $DATASET)

echo $DATANAME
DEVICEID=$2

+config_path="./${DATANAME}_config.yaml"
+echo "config path is : ${config_path}"
+
export DEVICE_NUM=1
export DEVICE_ID=$DEVICEID
export RANK_ID=0

@@ -46,10 +49,12 @@ then

fi
mkdir ./train
cp ../*.py ./train
+cp ../*.yaml ./train
cp -r ../src ./train
+cp -r ../model_utils ./train
cp -r ../scripts/*.sh ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
-python train.py --data_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
+python train.py --config_path $config_path --dataset_path $DATASET --data_name $DATANAME > log_fasttext.log 2>&1 &
cd ..

@@ -32,16 +32,21 @@ echo $DATASET

DATANAME=$(basename $DATASET)
echo $DATANAME

+config_path="./${DATANAME}_config.yaml"
+echo "config path is : ${config_path}"

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
+cp ../*.yaml ./train
cp -r ../src ./train
+cp -r ../model_utils ./train
cp -r ../scripts/*.sh ./train
cd ./train || exit
echo "start training for standalone GPU device"

-python train.py --device_target="GPU" --data_path=$1 --data_name=$DATANAME > log_fasttext.log 2>&1 &
+python train.py --config_path=$config_path --device_target="GPU" --dataset_path=$1 --data_name=$DATANAME > log_fasttext.log 2>&1 &
cd ..
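Taken together, the script changes are mechanical: copy the yaml configs and model_utils into the working directory, derive config_path from the dataset name, and pass --config_path/--dataset_path instead of the old --data_path. A hypothetical GPU run with the updated scripts might look like:

```bash
# 8-card GPU training, then standalone GPU eval (paths are placeholders)
cd ./scripts
sh run_distribute_train_gpu.sh /data/fasttext/ag 8
sh run_eval_gpu.sh /data/fasttext/ag_test ag ./eval/fasttext-5_116.ckpt
```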
@@ -1,135 +0,0 @@

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed

config_yelpp = ed({
    'vocab_size': 6414979,
    'buckets': [64, 128, 256, 512, 2955],
    'test_buckets': [64, 128, 256, 512, 2955],
    'batch_size': 2048,
    'embedding_dims': 16,
    'num_class': 2,
    'epoch': 5,
    'lr': 0.30,
    'min_lr': 1e-6,
    'decay_steps': 549,
    'warmup_steps': 400000,
    'poly_lr_scheduler_power': 0.5,
    'epoch_count': 1,
    'pretrain_ckpt_dir': None,
    'save_ckpt_steps': 549,
    'keep_ckpt_max': 10,
})

config_db = ed({
    'vocab_size': 6596536,
    'buckets': [64, 128, 256, 512, 3013],
    'test_buckets': [64, 128, 256, 512, 1120],
    'batch_size': 4096,
    'embedding_dims': 16,
    'num_class': 14,
    'epoch': 5,
    'lr': 0.8,
    'min_lr': 1e-6,
    'decay_steps': 549,
    'warmup_steps': 400000,
    'poly_lr_scheduler_power': 0.5,
    'epoch_count': 1,
    'pretrain_ckpt_dir': None,
    'save_ckpt_steps': 548,
    'keep_ckpt_max': 10,
})

config_ag = ed({
    'vocab_size': 1383812,
    'buckets': [64, 128, 467],
    'test_buckets': [467],
    'batch_size': 512,
    'embedding_dims': 16,
    'num_class': 4,
    'epoch': 5,
    'lr': 0.2,
    'min_lr': 1e-6,
    'decay_steps': 115,
    'warmup_steps': 400000,
    'poly_lr_scheduler_power': 0.001,
    'epoch_count': 1,
    'pretrain_ckpt_dir': None,
    'save_ckpt_steps': 116,
    'keep_ckpt_max': 10,
})

config_yelpp_gpu = ed({
    'vocab_size': 6414979,
    'buckets': [64, 128, 256, 512, 2955],
    'test_buckets': [64, 128, 256, 512, 2955],
    'batch_size': 2048,
    'distribute_batch_size': 512,
    'embedding_dims': 16,
    'num_class': 2,
    'epoch': 5,
    'lr': 0.30,
    'min_lr': 1e-6,
    'decay_steps': 549,
    'warmup_steps': 400000,
    'poly_lr_scheduler_power': 0.5,
    'epoch_count': 1,
    'pretrain_ckpt_dir': None,
    'save_ckpt_steps': 549,
    'keep_ckpt_max': 10,
})

config_db_gpu = ed({
    'vocab_size': 6596536,
    'buckets': [64, 128, 256, 512, 3013],
    'test_buckets': [64, 128, 256, 512, 1120],
    'batch_size': 4096,
    'distribute_batch_size': 512,
    'embedding_dims': 16,
    'num_class': 14,
    'epoch': 5,
    'lr': 0.8,
    'min_lr': 1e-6,
    'decay_steps': 549,
    'warmup_steps': 400000,
    'poly_lr_scheduler_power': 0.5,
    'epoch_count': 1,
    'pretrain_ckpt_dir': None,
    'save_ckpt_steps': 548,
    'keep_ckpt_max': 10,
})

config_ag_gpu = ed({
    'vocab_size': 1383812,
    'buckets': [64, 128, 467],
    'test_buckets': [467],
    'batch_size': 512,
    'distribute_batch_size': 64,
    'embedding_dims': 16,
    'num_class': 4,
    'epoch': 5,
    'lr': 0.2,
    'min_lr': 1e-6,
    'decay_steps': 115,
    'warmup_steps': 400000,
    'poly_lr_scheduler_power': 0.001,
    'epoch_count': 1,
    'pretrain_ckpt_dir': None,
    'save_ckpt_steps': 116,
    'keep_ckpt_max': 10,
})
@@ -15,9 +15,8 @@

"""FastText for train"""
import os
import time
-import argparse
-import ast
from mindspore import context
+from mindspore.communication.management import init, get_rank
from mindspore.nn.optim import Adam
from mindspore.common import set_seed
from mindspore.train.model import Model

@@ -28,41 +27,23 @@ from mindspore.train.callback import Callback, TimeMonitor

from mindspore.communication import management as MultiDevice
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.load_dataset import load_dataset
from src.lr_schedule import polynomial_decay_scheduler
from src.fasttext_train import FastTextTrainOneStepCell, FastTextNetWithLoss

-parser = argparse.ArgumentParser()
-parser.add_argument('--data_path', type=str, required=True, help='FastText input data file path.')
-parser.add_argument('--data_name', type=str, required=True, default='ag', help='dataset name. eg. ag, dbpedia')
-parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
-                    help='device where the code will be implemented (default: Ascend)')
-parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute, default: false.')
-
-args = parser.parse_args()
-
-if args.data_name == "ag":
-    from src.config import config_ag as config_ascend
-    from src.config import config_ag_gpu as config_gpu
-elif args.data_name == 'dbpedia':
-    from src.config import config_db as config_ascend
-    from src.config import config_db_gpu as config_gpu
-elif args.data_name == 'yelp_p':
-    from src.config import config_yelpp as config_ascend
-    from src.config import config_yelpp_gpu as config_gpu
+from model_utils.config import config
+from model_utils.moxing_adapter import moxing_wrapper
+from model_utils.device_adapter import get_device_id, get_device_num

def get_ms_timestamp():
    t = time.time()
    return int(round(t * 1000))

set_seed(5)
time_stamp_init = False
time_stamp_first = 0
-rank_id = os.getenv('DEVICE_ID')
-context.set_context(
-    mode=context.GRAPH_MODE,
-    save_graphs=False,
-    device_target=args.device_target)
-config = config_ascend if args.device_target == 'Ascend' else config_gpu
+context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target=config.device_target)

class LossCallBack(Callback):
    """

@@ -142,7 +123,7 @@ def _build_training_pipeline(pre_dataset, run_distribute=False):

    net_with_grads = FastTextTrainOneStepCell(net_with_loss, optimizer=optimizer)
    net_with_grads.set_train(True)
    model = Model(net_with_grads)
-    loss_monitor = LossCallBack(rank_ids=rank_id)
+    loss_monitor = LossCallBack(rank_ids=config.rank_id)
    dataset_size = pre_dataset.get_dataset_size()
    time_monitor = TimeMonitor(data_size=dataset_size)
    ckpt_config = CheckpointConfig(save_checkpoint_steps=decay_steps * config.epoch,

@@ -150,12 +131,14 @@ def _build_training_pipeline(pre_dataset, run_distribute=False):

    callbacks = [time_monitor, loss_monitor]
    if not run_distribute:
        ckpt_callback = ModelCheckpoint(prefix='fasttext',
-                                        directory=os.path.join('./', 'ckpt_{}'.format(os.getenv("DEVICE_ID"))),
+                                        directory=os.path.join(config.save_ckpt_dir,
+                                                               'ckpt_{}'.format(os.getenv("DEVICE_ID"))),
                                        config=ckpt_config)
        callbacks.append(ckpt_callback)
    if run_distribute and MultiDevice.get_rank() % 8 == 0:
        ckpt_callback = ModelCheckpoint(prefix='fasttext',
-                                        directory=os.path.join('./', 'ckpt_{}'.format(os.getenv("DEVICE_ID"))),
+                                        directory=os.path.join(config.save_ckpt_dir,
+                                                               'ckpt_{}'.format(os.getenv("DEVICE_ID"))),
                                        config=ckpt_config)
        callbacks.append(ckpt_callback)
    print("Prepare to Training....")

@@ -186,6 +169,7 @@ def set_parallel_env():

    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      device_num=MultiDevice.get_group_size(),
                                      gradients_mean=True)

+
def train_paralle(input_file_path):
    """
    Train model on multi device

@@ -195,8 +179,8 @@ def train_paralle(input_file_path):

    set_parallel_env()
    print("Starting traning on multiple devices. |~ _ ~| |~ _ ~| |~ _ ~| |~ _ ~|")
    batch_size = config.batch_size
-    if args.device_target == 'GPU':
-        batch_size = config.distribute_batch_size
+    if config.device_target == 'GPU':
+        batch_size = config.distribute_batch_size_gpu

    preprocessed_data = load_dataset(dataset_path=input_file_path,
                                     batch_size=batch_size,

@@ -207,8 +191,76 @@ def train_paralle(input_file_path):

                                     shuffle=False)
    _build_training_pipeline(preprocessed_data, True)

-if __name__ == "__main__":
-    if args.run_distribute:
-        train_paralle(args.data_path)
-    else:
-        train_single(args.data_path)
+
+def modelarts_pre_process():
+    '''modelarts pre process function.'''
+    def unzip(zip_file, save_dir):
+        import zipfile
+        s_time = time.time()
+        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
+            zip_isexist = zipfile.is_zipfile(zip_file)
+            if zip_isexist:
+                fz = zipfile.ZipFile(zip_file, 'r')
+                data_num = len(fz.namelist())
+                print("Extract Start...")
+                print("unzip file num: {}".format(data_num))
+                data_print = int(data_num / 100) if data_num > 100 else 1
+                i = 0
+                for file in fz.namelist():
+                    if i % data_print == 0:
+                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
+                    i += 1
+                    fz.extract(file, save_dir)
+                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
+                                                     int(int(time.time() - s_time) % 60)))
+                print("Extract Done.")
+            else:
+                print("This is not zip.")
+        else:
+            print("Zip has been extracted.")
+
+    if config.need_modelarts_dataset_unzip:
+        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
+        save_dir_1 = os.path.join(config.data_path)
+
+        sync_lock = "/tmp/unzip_sync.lock"
+
+        # Each server contains 8 devices as most.
+        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
+            print("Zip file path: ", zip_file_1)
+            print("Unzip file save dir: ", save_dir_1)
+            unzip(zip_file_1, save_dir_1)
+            print("===Finish extract data synchronization===")
+            try:
+                os.mknod(sync_lock)
+            except IOError:
+                pass
+
+        while True:
+            if os.path.exists(sync_lock):
+                break
+            time.sleep(1)
+
+        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
+
+    config.save_ckpt_dir = os.path.join(config.output_path, config.save_ckpt_dir)
+
+
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_train():
+    '''run train.'''
+    if config.device_target == "Ascend":
+        config.rank_id = get_device_id()
+    elif config.device_target == "GPU":
+        init("nccl")
+        config.rank_id = get_rank()
+    else:
+        raise ValueError("Not support device target: {}".format(config.device_target))
+
+    if config.run_distribute:
+        train_paralle(config.dataset_path)
+    else:
+        train_single(config.dataset_path)
+
+
+if __name__ == "__main__":
+    run_train()
@@ -0,0 +1,57 @@

# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
need_modelarts_dataset_unzip: False
modelarts_dataset_unzip_name: ""

# ==============================================================================
# options
vocab_size: 6414979
buckets: [64, 128, 256, 512, 2955]
test_buckets: [64, 128, 256, 512, 2955]
batch_size: 2048
embedding_dims: 16
num_class: 2
epoch: 5
lr: 0.30
min_lr: 0.000001 # 1e-6
decay_steps: 549
warmup_steps: 400000
poly_lr_scheduler_power: 0.5
epoch_count: 1
pretrain_ckpt_dir: ""
save_ckpt_steps: 549
save_ckpt_dir: "./"
keep_ckpt_max: 10
distribute_batch_size_gpu: 512

dataset_path: ""
data_name: "yelp_p"
run_distribute: False
model_ckpt: ""
# export option
device_id: 0
ckpt_file: ""
file_name: "fasttexts"
file_format: "AIR"
---

# Help description for each configuration
device_target: "Device target"
dataset_path: "FastText input data file path."
data_name: "dataset name. choice in ['ag', 'dbpedia', 'yelp_p']"
run_distribute: "Run distribute, default: false."
model_ckpt: "existed checkpoint address."
# export option
device_id: "Device id"
ckpt_file: "Checkpoint file path"
file_name: "Output file name"
file_format: "Output file format, choice in ['AIR', 'ONNX', 'MINDIR']"