forked from mindspore-Ecosystem/mindspore

autodis_can_been_used_on_ModelArts

This commit is contained in:
parent 63d853cf35
commit 924e8f3aa6
@@ -57,7 +57,7 @@ After installing MindSpore via the official website, you can start training and
 ```python
 # run training example
 python train.py \
-  --dataset_path='dataset/train' \
+  --train_data_dir='dataset/train' \
   --ckpt_path='./checkpoint' \
   --eval_file_name='auc.log' \
   --loss_file_name='loss.log' \
@@ -66,11 +66,11 @@ After installing MindSpore via the official website, you can start training and
 # run evaluation example
 python eval.py \
-  --dataset_path='dataset/test' \
+  --test_data_dir='dataset/test' \
   --checkpoint_path='./checkpoint/autodis.ckpt' \
   --device_target='Ascend' > ms_log/eval_output.log 2>&1 &
 OR
-sh scripts/run_eval.sh 0 Ascend /dataset_path /checkpoint_path/autodis.ckpt
+sh scripts/run_eval.sh 0 Ascend /test_data_dir /checkpoint_path/autodis.ckpt
 ```
 
 For distributed training, an hccl configuration file in JSON format needs to be created in advance.
@@ -79,6 +79,50 @@ After installing MindSpore via the official website, you can start training and
 <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>.
 
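For reference, the rank table that hccl_tools generates for a single device typically looks like the sketch below. This example is editorial, not part of the commit; every id and address in it is a placeholder rather than a value taken from this repository:

```json
{
    "version": "1.0",
    "server_count": "1",
    "server_list": [
        {
            "server_id": "10.0.0.1",
            "device": [
                {"device_id": "0", "device_ip": "192.168.100.101", "rank_id": "0"}
            ],
            "host_nic_ip": "reserve"
        }
    ],
    "status": "completed"
}
```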
+- running on ModelArts
+
+  If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/), and then you can start training as follows.
+
+  - Training with a single card on ModelArts
+
+    ```python
+    # (1) Upload the code folder to an S3 bucket.
+    # (2) Click "create training task" on the website UI interface.
+    # (3) Set the code directory to "/{path}/autodis" on the website UI interface.
+    # (4) Set the startup file to "/{path}/autodis/train.py" on the website UI interface.
+    # (5) Perform a or b.
+    #     a. setting parameters in /{path}/autodis/default_config.yaml.
+    #         1. Set "enable_modelarts: True"
+    #     b. adding on the website UI interface.
+    #         1. Add "enable_modelarts=True"
+    # (6) Upload the dataset to the S3 bucket.
+    # (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path.
+    # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
+    # (9) Under the item "resource pool selection", select the specification of a single card.
+    # (10) Create your job.
+    ```
+
+  - Evaluating with a single card on ModelArts
+
+    ```python
+    # (1) Upload the code folder to an S3 bucket.
+    # (2) Click "create training task" on the website UI interface.
+    # (3) Set the code directory to "/{path}/autodis" on the website UI interface.
+    # (4) Set the startup file to "/{path}/autodis/eval.py" on the website UI interface.
+    # (5) Perform a or b.
+    #     a. setting parameters in /{path}/autodis/default_config.yaml.
+    #         1. Set "enable_modelarts: True"
+    #         2. Set "checkpoint_path: ./{path}/*.ckpt" ('checkpoint_path' indicates the path of the weight file to be evaluated, relative to the file `eval.py`; the weight file must be included in the code directory.)
+    #     b. adding on the website UI interface.
+    #         1. Add "enable_modelarts=True"
+    #         2. Add "checkpoint_path=./{path}/*.ckpt" ('checkpoint_path' indicates the path of the weight file to be evaluated, relative to the file `eval.py`; the weight file must be included in the code directory.)
+    # (6) Upload the dataset to the S3 bucket.
+    # (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (there is only data or a zip package under this path).
+    # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
+    # (9) Under the item "resource pool selection", select the specification of a single card.
+    # (10) Create your job.
+    ```
 
 # [Script Description](#contents)
 
 ## [Script and Sample Code](#contents)
@@ -86,54 +130,52 @@ After installing MindSpore via the official website, you can start training and
 ```bash
 .
 └─autodis
-  ├─README.md
-  ├─mindspore_hub_conf.md             # config for mindspore hub
+  ├─README.md                         # descriptions of AutoDis
+  ├─ascend310_infer                   # application for 310 inference
   ├─scripts
     ├─run_standalone_train.sh         # launch standalone training(1p) in Ascend or GPU
+    ├─run_infer_310.sh                # launch 310infer
     └─run_eval.sh                     # launch evaluating in Ascend or GPU
   ├─src
+    ├─model_utils
+      ├─config.py                     # parsing parameter configuration file of "*.yaml"
+      ├─device_adapter.py             # local or ModelArts training
+      ├─local_adapter.py              # get related environment variables in local training
+      └─moxing_adapter.py             # get related environment variables in ModelArts training
     ├─__init__.py                     # python init file
-    ├─config.py                       # parameter configuration
     ├─callback.py                     # define callback function
     ├─autodis.py                      # AutoDis network
-    ├─dataset.py                      # create dataset for AutoDis
-  ├─eval.py                           # eval net
-  └─train.py                          # train net
+    └─dataset.py                      # create dataset for AutoDis
+  ├─default_config.yaml               # parameter configuration
+  ├─eval.py                           # eval script
+  ├─export.py                         # export checkpoint file into air/mindir
+  ├─mindspore_hub_conf.py             # mindspore hub interface
+  ├─postprocess.py                    # 310infer postprocess script
+  ├─preprocess.py                     # 310infer preprocess script
+  └─train.py                          # train script
 ```
 
 ## [Script Parameters](#contents)
 
-Parameters for both training and evaluation can be set in config.py
+Parameters for both training and evaluation can be set in `default_config.yaml`
 
-- train parameters
-
-  ```python
-  optional arguments:
-    -h, --help            show this help message and exit
-    --dataset_path DATASET_PATH
-                          Dataset path
-    --ckpt_path CKPT_PATH
-                          Checkpoint path
-    --eval_file_name EVAL_FILE_NAME
-                          Auc log file path. Default: "./auc.log"
-    --loss_file_name LOSS_FILE_NAME
-                          Loss log file path. Default: "./loss.log"
-    --do_eval DO_EVAL     Do evaluation or not. Default: True
-    --device_target DEVICE_TARGET
-                          Ascend or GPU. Default: Ascend
-  ```
-
-- eval parameters
+- Parameters that can be modified at the terminal
 
   ```bash
-  optional arguments:
-    -h, --help            show this help message and exit
-    --checkpoint_path CHECKPOINT_PATH
-                          Checkpoint file path
-    --dataset_path DATASET_PATH
-                          Dataset path
-    --device_target DEVICE_TARGET
-                          Ascend or GPU. Default: Ascend
+  # Train
+  train_data_dir: ''                  # train dataset path
+  ckpt_path: 'ckpts'                  # the folder path to save '*.ckpt' files. Relative path.
+  eval_file_name: "./auc.log"         # file path to record accuracy
+  loss_file_name: "./loss.log"        # file path to record loss
+  do_eval: "True"                     # whether to do eval while training, default is 'True'.
+  # Test
+  test_data_dir: ''                   # test dataset path
+  checkpoint_path: ''                 # the path of the weight file to be evaluated, relative to `eval.py`; the weight file must be included in the code directory.
+  # Export
+  batch_size: 16000                   # batch_size for the exported model.
+  ckpt_file: ''                       # the path of the weight file to be exported, relative to `export.py`; the weight file must be included in the code directory.
+  file_name: "autodis"                # output file name.
+  file_format: "AIR"                  # output file format, you can choose from AIR or MINDIR, default is AIR
   ```
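Every scalar key in `default_config.yaml` is also registered as a command-line flag by `src/model_utils/config.py` (the new file further below), so the values above can be overridden at the terminal. For example, with hypothetical paths:

```bash
python train.py --train_data_dir=dataset/train --ckpt_path=./checkpoint --do_eval=True
```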
 
 ## [Training Process](#contents)
 
@@ -144,7 +186,7 @@ Parameters for both training and evaluation can be set in config.py
 ```python
 python train.py \
-  --dataset_path='dataset/train' \
+  --train_data_dir='dataset/train' \
   --ckpt_path='./checkpoint' \
   --eval_file_name='auc.log' \
   --loss_file_name='loss.log' \
@@ -174,11 +216,11 @@ Parameters for both training and evaluation can be set in config.py
 ```python
 python eval.py \
-  --dataset_path='dataset/test' \
+  --test_data_dir='dataset/test' \
   --checkpoint_path='./checkpoint/autodis.ckpt' \
   --device_target='Ascend' > ms_log/eval_output.log 2>&1 &
 OR
-sh scripts/run_eval.sh 0 Ascend /dataset_path /checkpoint_path/autodis.ckpt
+sh scripts/run_eval.sh 0 Ascend /test_data_dir /checkpoint_path/autodis.ckpt
 ```
 
 The above python command will run in the background. You can view the results through the file "eval_output.log". The accuracy is saved in the auc.log file.
@@ -191,12 +233,37 @@ Parameters for both training and evaluation can be set in config.py
 ### [Export MindIR](#contents)
 
-```shell
-python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
-```
+- Export on local
 
-The ckpt_file parameter is required,
-`file_format` should be in ["AIR", "MINDIR"]
+  ```shell
+  # The ckpt_file parameter is required, `FILE_FORMAT` should be in ["AIR", "MINDIR"]
+  python export.py --ckpt_file [CKPT_PATH] --file_name [FILE_NAME] --file_format [FILE_FORMAT]
+  ```
+
+- Export on ModelArts (If you want to run on ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/), and then you can start as follows)
+
+  ```python
+  # (1) Upload the code folder to an S3 bucket.
+  # (2) Click "create training task" on the website UI interface.
+  # (3) Set the code directory to "/{path}/autodis" on the website UI interface.
+  # (4) Set the startup file to "/{path}/autodis/export.py" on the website UI interface.
+  # (5) Perform a or b.
+  #     a. setting parameters in /{path}/autodis/default_config.yaml.
+  #         1. Set "enable_modelarts: True"
+  #         2. Set "ckpt_file: ./{path}/*.ckpt" ('ckpt_file' indicates the path of the weight file to be exported, relative to the file `export.py`; the weight file must be included in the code directory.)
+  #         3. Set "file_name: autodis"
+  #         4. Set "file_format: AIR" (you can choose from AIR or MINDIR)
+  #     b. adding on the website UI interface.
+  #         1. Add "enable_modelarts=True"
+  #         2. Add "ckpt_file=./{path}/*.ckpt" ('ckpt_file' indicates the path of the weight file to be exported, relative to the file `export.py`; the weight file must be included in the code directory.)
+  #         3. Add "file_name=autodis"
+  #         4. Add "file_format=AIR" (you can choose from AIR or MINDIR)
+  # (7) Check the "data storage location" on the website UI interface and set the "Dataset path" path (this step has no effect on export, but the UI requires it).
+  # (8) Set the "Output file path" and "Job log path" to your path on the website UI interface.
+  # (9) Under the item "resource pool selection", select the specification of a single card.
+  # (10) Create your job.
+  # You will see autodis.air under "Output file path".
+  ```
 
 ### Infer on Ascend310
@@ -0,0 +1,91 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
enable_profiling: False
# ==============================================================================
# Parameters that can be modified at the terminal
# Train
train_data_dir: ''
ckpt_path: 'ckpts'
eval_file_name: "./auc.log"
loss_file_name: "./loss.log"
do_eval: "True"
# Test
test_data_dir: ''
checkpoint_path: ''
# Export
batch_size: 16000
ckpt_file: ''
file_name: "autodis"
file_format: "AIR"
# Dataset related
DataConfig:
  data_vocab_size: 184965
  train_num_of_parts: 21
  test_num_of_parts: 3
  batch_size: 1000
  data_field_size: 39
  # dataset format, 1: mindrecord, 2: tfrecord, 3: h5
  data_format: 2
# Model related
ModelConfig:
  batch_size: DataConfig.batch_size
  data_field_size: DataConfig.data_field_size
  data_vocab_size: DataConfig.data_vocab_size
  data_emb_dim: 80
  deep_layer_args: [[400, 400, 512], "relu"]
  init_args: [-0.01, 0.01]
  weight_bias_init: ['normal', 'normal']
  keep_prob: 0.9
  split_index: 13
  hash_size: 20
  temperature: 0.00001 # 1e-5
# Training related
TrainConfig:
  batch_size: DataConfig.batch_size
  l2_coef: 0.000001 # 1e-6
  learning_rate: 0.00001 # 1e-5
  epsilon: 0.00000001 # 1e-8
  loss_scale: 1024.0
  train_epochs: 15
  save_checkpoint: True
  ckpt_file_name_prefix: "autodis"
  save_checkpoint_steps: 1
  keep_checkpoint_max: 15
  eval_callback: True
  loss_callback: True

---
# Help description for each configuration
# Parameters that are used on ModelArts
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: "Running platform, choose from Ascend, and default is Ascend."
enable_profiling: 'Whether enable profiling while training, default: False'
# Parameters that can be modified at the terminal
train_data_dir: "Train dataset dir, default is None"
ckpt_path: "ckpt dir to save, default is None"
eval_file_name: "Auc log file path. Default: './auc.log'"
loss_file_name: "Loss log file path. Default: './loss.log'"
do_eval: 'Do evaluation or not, only support "True" or "False". Default: "True"'
test_data_dir: "Test dataset dir, default is None"
checkpoint_path: "Path of the '*.ckpt' file to be evaluated, relative to eval.py"
ckpt_file: "Checkpoint file path."
file_name: "Output file name."
file_format: "Output file format, you can choose from AIR or MINDIR, default is AIR"

---
# Choices
device_target: ["Ascend"]
do_eval: ["True", "False"]
file_format: ["AIR", "MINDIR"]
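Note that entries such as `batch_size: DataConfig.batch_size` under `ModelConfig` and `TrainConfig` are read by the YAML parser as literal strings; they only act as placeholders, and `extra_operations` in the new `src/model_utils/config.py` (further below) overwrites them with the concrete values from `DataConfig`, roughly as follows:

```python
# sketch of what extra_operations() in src/model_utils/config.py does with the placeholders above
cfg.ModelConfig.batch_size = cfg.DataConfig.batch_size            # 1000
cfg.ModelConfig.data_field_size = cfg.DataConfig.data_field_size  # 39
cfg.ModelConfig.data_vocab_size = cfg.DataConfig.data_vocab_size  # 184965
cfg.TrainConfig.batch_size = cfg.DataConfig.batch_size            # 1000
```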
@@ -16,47 +16,43 @@
 import os
-import sys
 import time
-import argparse
 
 from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 
 from src.autodis import ModelBuilder, AUCMetric
-from src.config import DataConfig, ModelConfig, TrainConfig
+from src.model_utils.config import config, data_config, model_config, train_config
 from src.dataset import create_dataset, DataType
+from src.model_utils.moxing_adapter import moxing_wrapper
+from src.model_utils.device_adapter import get_device_id
 
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-parser = argparse.ArgumentParser(description='CTR Prediction')
-parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
-parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
-parser.add_argument('--device_target', type=str, default="Ascend", choices=["Ascend"],
-                    help='Default: Ascend')
-args_opt, _ = parser.parse_known_args()
-device_id = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
 
 
 def add_write(file_path, print_str):
     with open(file_path, 'a+', encoding='utf-8') as file_out:
         file_out.write(print_str + '\n')
 
+def modelarts_pre_process():
+    '''modelarts pre process function.'''
+    config.test_data_dir = config.data_path
+    config.checkpoint_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), config.checkpoint_path)
 
-if __name__ == '__main__':
-    data_config = DataConfig()
-    model_config = ModelConfig()
-    train_config = TrainConfig()
-
-    ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_eval():
+    '''eval function'''
+    device_id = get_device_id()
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=device_id)
+    ds_eval = create_dataset(config.test_data_dir, train_mode=False,
                              epochs=1, batch_size=train_config.batch_size,
                              data_type=DataType(data_config.data_format))
-    model_builder = ModelBuilder(ModelConfig, TrainConfig)
+    model_builder = ModelBuilder(model_config, train_config)
     train_net, eval_net = model_builder.get_train_eval_net()
     train_net.set_train()
     eval_net.set_train(False)
     auc_metric = AUCMetric()
     model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
-    param_dict = load_checkpoint(args_opt.checkpoint_path)
+    param_dict = load_checkpoint(config.checkpoint_path)
     load_param_into_net(eval_net, param_dict)
 
     start = time.time()
@@ -66,3 +62,6 @@ if __name__ == '__main__':
     out_str = f'{time_str} AUC: {list(res.values())[0]}, eval time: {eval_time}s.'
     print(out_str)
     add_write('./auc.log', str(out_str))
+
+if __name__ == '__main__':
+    run_eval()
@@ -13,38 +13,42 @@
 # limitations under the License.
 # ============================================================================
 """export ckpt to model"""
-import argparse
+import os
 import numpy as np
 
 from mindspore import context, Tensor
 from mindspore.train.serialization import export, load_checkpoint
 
 from src.autodis import ModelBuilder
-from src.config import DataConfig, ModelConfig, TrainConfig
+from src.model_utils.config import config, data_config, model_config, train_config
+from src.model_utils.device_adapter import get_device_id
+from src.model_utils.moxing_adapter import moxing_wrapper
 
-parser = argparse.ArgumentParser(description="autodis export")
-parser.add_argument("--device_id", type=int, default=0, help="Device id")
-parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
-parser.add_argument("--file_name", type=str, default="autodis", help="output file name.")
-parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
-parser.add_argument("--device_target", type=str, choices=["Ascend"], default="Ascend",
-                    help="device target")
-args = parser.parse_args()
-
-context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
+def modelarts_pre_process():
+    '''modelarts pre process function.'''
+    config.file_name = os.path.join(config.output_path, config.file_name)
+    config.ckpt_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), config.ckpt_file)
 
-if __name__ == "__main__":
-    data_config = DataConfig()
-
-    model_builder = ModelBuilder(ModelConfig, TrainConfig)
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_export():
+    '''export checkpoint file into air/mindir'''
+    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
+
+    model_builder = ModelBuilder(model_config, train_config)
     _, network = model_builder.get_train_eval_net()
     network.set_train(False)
 
-    load_checkpoint(args.ckpt_file, net=network)
+    load_checkpoint(config.ckpt_file, net=network)
 
     batch_ids = Tensor(np.zeros([data_config.batch_size, data_config.data_field_size]).astype(np.int32))
     batch_wts = Tensor(np.zeros([data_config.batch_size, data_config.data_field_size]).astype(np.float32))
     labels = Tensor(np.zeros([data_config.batch_size, 1]).astype(np.float32))
 
     input_data = [batch_ids, batch_wts, labels]
-    export(network, *input_data, file_name=args.file_name, file_format=args.file_format)
+    export(network, *input_data, file_name=config.file_name, file_format=config.file_format)
+
+
+if __name__ == "__main__":
+    run_export()
@@ -14,12 +14,10 @@
 # ============================================================================
 """hub config."""
 from src.autodis import ModelBuilder
-from src.config import ModelConfig, TrainConfig
+from src.model_utils.config import model_config, train_config
 
 def create_network(name, *args, **kwargs):
     if name == 'autodis':
-        model_config = ModelConfig()
-        train_config = TrainConfig()
         model_builder = ModelBuilder(model_config, train_config)
         _, autodis_eval_net = model_builder.get_train_eval_net()
         return autodis_eval_net
@@ -18,7 +18,7 @@ import argparse
 import numpy as np
 from mindspore import Tensor
 from src.autodis import AUCMetric
-from src.config import TrainConfig
+from src.model_utils.config import train_config
 
 parser = argparse.ArgumentParser(description='postprocess')
 parser.add_argument('--result_path', type=str, default="./result_Files", help='result path')
@@ -28,7 +28,6 @@ args_opt, _ = parser.parse_known_args()
 def get_acc():
     ''' get accuracy '''
     auc_metric = AUCMetric()
-    train_config = TrainConfig()
     files = os.listdir(args_opt.label_path)
     batch_size = train_config.batch_size
@@ -16,7 +16,7 @@
 import os
 import argparse
 
-from src.config import DataConfig, TrainConfig
+from src.model_utils.config import data_config, train_config
 from src.dataset import create_dataset, DataType
 
 parser = argparse.ArgumentParser(description='preprocess.')
@@ -24,11 +24,9 @@ parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path
 parser.add_argument('--result_path', type=str, default='./preprocess_Result', help='Result path')
 args_opt, _ = parser.parse_known_args()
 
 
 def generate_bin():
     '''generate bin files'''
-    data_config = DataConfig()
-    train_config = TrainConfig()
-
     ds = create_dataset(args_opt.dataset_path, train_mode=False,
                         epochs=1, batch_size=train_config.batch_size,
                         data_type=DataType(data_config.data_format))
@@ -53,5 +51,6 @@ def generate_bin():
     print("=" * 20, "export bin files finished", "=" * 20)
 
+
 if __name__ == '__main__':
     generate_bin()
 
@@ -14,14 +14,34 @@
 # limitations under the License.
 # ============================================================================
 echo "Please run the script as: "
-echo "sh scripts/run_eval.sh DEVICE_ID DEVICE_TARGET DATASET_PATH CHECKPOINT_PATH"
+echo "sh scripts/run_eval.sh [DEVICE_ID] [DEVICE_TARGET] [TEST_DATA_DIR] [CHECKPOINT_PATH]"
 echo "for example: sh scripts/run_eval.sh 0 GPU /dataset_path /checkpoint_path"
 echo "After running the script, the network runs in the background. The log will be generated in ms_log/eval_output.log"
 
-export DEVICE_ID=$1
-DEVICE_TARGET=$2
-DATA_URL=$3
-CHECKPOINT_PATH=$4
+DATA_URL=$(readlink -f "$3")
+CHECKPOINT_PATH=$(readlink -f "$4")
+
+DEVICE_TARGET=$2
+if [ "$DEVICE_TARGET" = "GPU" ]; then
+    export CUDA_VISIBLE_DEVICES=$1
+elif [ "$DEVICE_TARGET" = "Ascend" ]; then
+    export DEVICE_ID=$1
+else
+    echo "Unsupported platform:$DEVICE_TARGET"
+    exit 1
+fi
+
+abs_path=$(readlink -f "$0")
+cur_path=$(dirname $abs_path)
+cd $cur_path
+
+rm -rf ./eval_$DEVICE_TARGET
+mkdir ./eval_$DEVICE_TARGET
+cp ../eval.py ./eval_$DEVICE_TARGET
+cp ../*.yaml ./eval_$DEVICE_TARGET
+cp -r ../src ./eval_$DEVICE_TARGET
+cd ./eval_$DEVICE_TARGET || exit
 
 mkdir -p ms_log
 CUR_DIR=`pwd`
@@ -29,6 +49,6 @@ export GLOG_log_dir=${CUR_DIR}/ms_log
 export GLOG_logtostderr=0
 
 python -u eval.py \
-    --dataset_path=$DATA_URL \
+    --test_data_dir=$DATA_URL \
     --checkpoint_path=$CHECKPOINT_PATH \
     --device_target=$DEVICE_TARGET > ms_log/eval_output.log 2>&1 &
@@ -14,31 +14,41 @@
 # limitations under the License.
 # ============================================================================
 echo "Please run the script as: "
-echo "sh scripts/run_standalone_train.sh DEVICE_ID/CUDA_VISIBLE_DEVICES DEVICE_TARGET DATASET_PATH"
+echo "sh scripts/run_standalone_train.sh [DEVICE_ID/CUDA_VISIBLE_DEVICES] [DEVICE_TARGET] [TRAIN_DATA_DIR]"
 echo "for example: sh scripts/run_standalone_train.sh 0 GPU /dataset_path"
 echo "After running the script, the network runs in the background. The log will be generated in ms_log/output.log"
 
 DEVICE_TARGET=$2
 
-if [ "$DEVICE_TARGET" = "GPU" ]
-then
+if [ "$DEVICE_TARGET" = "GPU" ]; then
     export CUDA_VISIBLE_DEVICES=$1
-fi
-
-if [ "$DEVICE_TARGET" = "Ascend" ]
-then
+elif [ "$DEVICE_TARGET" = "Ascend" ]; then
     export DEVICE_ID=$1
+else
+    echo "Unsupported platform:$DEVICE_TARGET"
+    exit 1
 fi
 
-DATA_URL=$3
+DATA_URL=$(readlink -f "$3")
+
+abs_path=$(readlink -f "$0")
+cur_path=$(dirname $abs_path)
+cd $cur_path
+
+rm -rf ./train_single_$DEVICE_TARGET
+mkdir ./train_single_$DEVICE_TARGET
+cp ../train.py ./train_single_$DEVICE_TARGET
+cp ../*.yaml ./train_single_$DEVICE_TARGET
+cp -r ../src ./train_single_$DEVICE_TARGET
+cd ./train_single_$DEVICE_TARGET || exit
 
 mkdir -p ms_log
 CUR_DIR=`pwd`
 export GLOG_log_dir=${CUR_DIR}/ms_log
 export GLOG_logtostderr=0
 
+echo "Start train at platform:$DEVICE_TARGET, device_id:$DEVICE_ID"
 python -u train.py \
-    --dataset_path=$DATA_URL \
+    --train_data_dir=$DATA_URL \
     --ckpt_path="checkpoint" \
     --eval_file_name='auc.log' \
     --loss_file_name='loss.log' \
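For example, to launch standalone training on Ascend device 0 (the dataset directory below is a hypothetical placeholder):

```bash
sh scripts/run_standalone_train.sh 0 Ascend /data/criteo/train
```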
@@ -1,64 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""


class DataConfig:
    """
    Define parameters of dataset.
    """
    data_vocab_size = 184965
    train_num_of_parts = 21
    test_num_of_parts = 3
    batch_size = 1000
    data_field_size = 39
    # dataset format, 1: mindrecord, 2: tfrecord, 3: h5
    data_format = 2


class ModelConfig:
    """
    Define parameters of model.
    """
    batch_size = DataConfig.batch_size
    data_field_size = DataConfig.data_field_size
    data_vocab_size = DataConfig.data_vocab_size
    data_emb_dim = 80
    deep_layer_args = [[400, 400, 512], "relu"]
    init_args = [-0.01, 0.01]
    weight_bias_init = ['normal', 'normal']
    keep_prob = 0.9
    split_index = 13
    hash_size = 20
    temperature = 1e-5


class TrainConfig:
    """
    Define parameters of training.
    """
    batch_size = DataConfig.batch_size
    l2_coef = 1e-6
    learning_rate = 1e-5
    epsilon = 1e-8
    loss_scale = 1024.0
    train_epochs = 15
    save_checkpoint = True
    ckpt_file_name_prefix = "autodis"
    save_checkpoint_steps = 1
    keep_checkpoint_max = 15
    eval_callback = True
    loss_callback = True
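The deletion above is the heart of this commit: the hard-coded config classes give way to the YAML-backed namespaces. As a sketch of the substitution pattern applied across the scripts (the old form is shown commented out):

```python
# old style, removed in this commit:
# from src.config import DataConfig
# data_config = DataConfig()

# new style, added in this commit:
from src.model_utils.config import data_config

print(data_config.batch_size)  # 1000, taken from the DataConfig section of default_config.yaml
```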
@@ -24,7 +24,7 @@ import pandas as pd
 import mindspore.dataset as ds
 import mindspore.common.dtype as mstype
 
-from .config import DataConfig
+from src.model_utils.config import data_config as DataConfig
 
 
 class DataType(Enum):
@@ -0,0 +1,152 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Parse arguments"""

import os
import ast
import argparse
from pprint import pprint, pformat
import yaml

_config = "default_config.yaml"


class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path=_config):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        choices: Choices of each configuration.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
            cfgs = [x for x in cfgs]
            if len(cfgs) == 1:
                cfg_helper = {}
                cfg = cfgs[0]
                cfg_choices = {}
            elif len(cfgs) == 2:
                cfg, cfg_helper = cfgs
                cfg_choices = {}
            elif len(cfgs) == 3:
                cfg, cfg_helper, cfg_choices = cfgs
            else:
                raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
            print(cfg_helper)
        except:
            raise ValueError("Failed to parse yaml")
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def extra_operations(cfg):
    """
    Do extra work on Config object.

    Args:
        cfg: Object after instantiation of class 'Config'.
    """
    cfg.ModelConfig.batch_size = cfg.DataConfig.batch_size
    cfg.ModelConfig.data_field_size = cfg.DataConfig.data_field_size
    cfg.ModelConfig.data_vocab_size = cfg.DataConfig.data_vocab_size
    cfg.TrainConfig.batch_size = cfg.DataConfig.batch_size


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../{}".format(_config)),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    pprint(default)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    config_obj = Config(final_config)
    extra_operations(config_obj)
    return config_obj


config = get_config()
data_config = config.DataConfig
model_config = config.ModelConfig
train_config = config.TrainConfig

if __name__ == '__main__':
    print(config)
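As an editorial illustration of the three-document layout handled by `parse_yaml`: the function returns the YAML documents of `default_config.yaml` as `(cfg, cfg_helper, cfg_choices)`. A hypothetical stand-alone use (the normal entry point is the module-level `config = get_config()` above):

```python
from src.model_utils.config import parse_yaml

cfg, helper, choices = parse_yaml("default_config.yaml")
print(cfg["device_target"])        # "Ascend"
print(helper["enable_modelarts"])  # "Whether training on modelarts, default: False"
print(choices["file_format"])      # ["AIR", "MINDIR"]
```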
@@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Device adapter for ModelArts"""

from src.model_utils.config import config

if config.enable_modelarts:
    from src.model_utils.moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from src.model_utils.local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]
@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Local adapter"""

import os

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    return "Local Job"
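For illustration (editorial, not part of the commit), these helpers simply read the environment variables that the launch scripts export; the values below are hypothetical:

```python
import os
os.environ['DEVICE_ID'] = '3'   # run_standalone_train.sh exports this before calling train.py
os.environ['RANK_SIZE'] = '1'

from src.model_utils.local_adapter import get_device_id, get_device_num
print(get_device_id())   # 3
print(get_device_num())  # 1
```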
@@ -0,0 +1,123 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Moxing adapter for ModelArts"""

import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from src.model_utils.config import config

_global_sync_count = 0

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)


def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)


def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)


def get_job_id():
    job_id = os.getenv('JOB_ID')
    job_id = job_id if job_id != "" else "default"
    return job_id

def sync_data(from_path, to_path):
    """
    Download data from remote obs to local directory if the first url is remote url and the second one is local path
    Upload data from local directory to remote obs in contrast.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices as most.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
        except IOError:
            pass
        print("===save flag===")

    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            if config.enable_profiling:
                profiler = Profiler()

            run_func(*args, **kwargs)

            if config.enable_profiling:
                profiler.analyse()

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
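Putting the pieces together, an entry script wrapped with this decorator follows the pattern below. This is an editorial sketch of how `train.py` and `eval.py` in this commit use it; `my_pre_process` and `main` are hypothetical names:

```python
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper

def my_pre_process():
    # remap the yaml paths to the ModelArts cache directories before the job starts
    config.train_data_dir = config.data_path

@moxing_wrapper(pre_process=my_pre_process)
def main():
    # runs after the dataset has been synced from OBS (data_url) to /cache/data
    print("train with", config.train_data_dir)

if __name__ == '__main__':
    main()
```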
@@ -15,7 +15,6 @@
 """train_criteo."""
 import os
 import sys
-import argparse
 
 from mindspore import context
 from mindspore.context import ParallelMode
@@ -25,57 +24,50 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
 from mindspore.common import set_seed
 
 from src.autodis import ModelBuilder, AUCMetric
-from src.config import DataConfig, ModelConfig, TrainConfig
 from src.dataset import create_dataset, DataType
 from src.callback import EvalCallBack, LossCallBack
+from src.model_utils.moxing_adapter import moxing_wrapper
+from src.model_utils.config import config, train_config, data_config, model_config
+from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
 
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-parser = argparse.ArgumentParser(description='CTR Prediction')
-parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
-parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path')
-parser.add_argument('--eval_file_name', type=str, default="./auc.log",
-                    help='Auc log file path. Default: "./auc.log"')
-parser.add_argument('--loss_file_name', type=str, default="./loss.log",
-                    help='Loss log file path. Default: "./loss.log"')
-parser.add_argument('--do_eval', type=str, default='True', choices=["True", "False"],
-                    help='Do evaluation or not, only support "True" or "False". Default: "True"')
-parser.add_argument('--device_target', type=str, default="Ascend", choices=["Ascend"],
-                    help='Default: Ascend')
-args_opt, _ = parser.parse_known_args()
-args_opt.do_eval = args_opt.do_eval == 'True'
-rank_size = int(os.environ.get("RANK_SIZE", 1))
 
 set_seed(1)
 
-if __name__ == '__main__':
-    data_config = DataConfig()
-    model_config = ModelConfig()
-    train_config = TrainConfig()
+def modelarts_pre_process():
+    '''modelarts pre process function.'''
+    config.train_data_dir = config.data_path
+    config.ckpt_path = os.path.join(config.output_path, config.ckpt_path)
 
+@moxing_wrapper(pre_process=modelarts_pre_process)
+def run_train():
+    '''train function'''
+    config.do_eval = config.do_eval == 'True'
+    rank_size = get_device_num()
     if rank_size > 1:
-        if args_opt.device_target == "Ascend":
-            device_id = int(os.getenv('DEVICE_ID'))
-            context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
+        if config.device_target == "Ascend":
+            device_id = get_device_id()
+            context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=device_id)
             context.reset_auto_parallel_context()
             context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
            init()
-            rank_id = int(os.environ.get('RANK_ID'))
+            rank_id = get_rank_id()
         else:
-            print("Unsupported device_target ", args_opt.device_target)
+            print("Unsupported device_target ", config.device_target)
             exit()
     else:
-        if args_opt.device_target == "Ascend":
-            device_id = int(os.getenv('DEVICE_ID'))
-            context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
+        if config.device_target == "Ascend":
+            device_id = get_device_id()
+            context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=device_id)
         else:
-            print("Unsupported device_target ", args_opt.device_target)
+            print("Unsupported device_target ", config.device_target)
             exit()
         rank_size = None
         rank_id = None
 
     # Init Profiler
 
-    ds_train = create_dataset(args_opt.dataset_path,
+    ds_train = create_dataset(config.train_data_dir,
                               train_mode=True,
                               epochs=1,
                               batch_size=train_config.batch_size,
@@ -84,34 +76,36 @@ if __name__ == '__main__':
                               rank_id=rank_id)
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
 
-    steps_size = ds_train.get_dataset_size()
+    # steps_size = ds_train.get_dataset_size()
 
-    model_builder = ModelBuilder(ModelConfig, TrainConfig)
+    model_builder = ModelBuilder(model_config, train_config)
     train_net, eval_net = model_builder.get_train_eval_net()
     auc_metric = AUCMetric()
     model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
 
     time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
-    loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
+    loss_callback = LossCallBack(loss_file_path=config.loss_file_name)
     callback_list = [time_callback, loss_callback]
 
     if train_config.save_checkpoint:
         if rank_size:
             train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
-            args_opt.ckpt_path = os.path.join(args_opt.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
+            config.ckpt_path = os.path.join(config.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
         config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
                                      keep_checkpoint_max=train_config.keep_checkpoint_max)
         ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
-                                  directory=args_opt.ckpt_path,
+                                  directory=config.ckpt_path,
                                   config=config_ck)
         callback_list.append(ckpt_cb)
 
-    if args_opt.do_eval:
-        ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
+    if config.do_eval:
+        ds_eval = create_dataset(config.train_data_dir, train_mode=False,
                                  epochs=1,
                                  batch_size=train_config.batch_size,
                                  data_type=DataType(data_config.data_format))
         eval_callback = EvalCallBack(model, ds_eval, auc_metric,
-                                     eval_file_path=args_opt.eval_file_name)
+                                     eval_file_path=config.eval_file_name)
         callback_list.append(eval_callback)
     model.train(train_config.train_epochs, ds_train, callbacks=callback_list)
+
+
+if __name__ == '__main__':
+    run_train()