add gpu support for deepfm model

fixed pylint errors
This commit is contained in:
tom__chen 2020-08-18 08:29:38 -08:00
parent 0e27a04da1
commit 183ae5d009
5 changed files with 86 additions and 21 deletions

View File

@ -24,16 +24,16 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.deepfm import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
from src.dataset import create_dataset
from src.dataset import create_dataset, DataType
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default="Ascend", help='Ascend, GPU, or CPU')
args_opt, _ = parser.parse_known_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
def add_write(file_path, print_str):
@ -47,7 +47,8 @@ if __name__ == '__main__':
train_config = TrainConfig()
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
epochs=1, batch_size=train_config.batch_size)
epochs=1, batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format))
model_builder = ModelBuilder(ModelConfig, TrainConfig)
train_net, eval_net = model_builder.get_train_eval_net()
train_net.set_train()

View File

@ -0,0 +1,38 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH"
echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path"
echo "After running the script, the network runs in the background, The log will be generated in log/output.log"
export RANK_SIZE=$1
DATA_URL=$2
rm -rf log
mkdir ./log
cp *.py ./log
cp -r src ./log
cd ./log || exit
env > env.log
mpirun --allow-run-as-root -n $RANK_SIZE \
python -u train.py \
--dataset_path=$DATA_URL \
--ckpt_path="checkpoint" \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--device_target='GPU' \
--do_eval=True > output.log 2>&1 &

View File

@ -14,13 +14,14 @@
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_eval.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH"
echo "for example: sh scripts/run_eval.sh 0 /dataset_path /checkpoint_path"
echo "sh scripts/run_eval.sh DEVICE_ID DEVICE_TARGET DATASET_PATH CHECKPOINT_PATH"
echo "for example: sh scripts/run_eval.sh 0 GPU /dataset_path /checkpoint_path"
echo "After running the script, the network runs in the background, The log will be generated in ms_log/eval_output.log"
export DEVICE_ID=$1
DATA_URL=$2
CHECKPOINT_PATH=$3
DEVICE_TARGET=$2
DATA_URL=$3
CHECKPOINT_PATH=$4
mkdir -p ms_log
CUR_DIR=`pwd`
@ -29,4 +30,5 @@ export GLOG_logtostderr=0
python -u eval.py \
--dataset_path=$DATA_URL \
--checkpoint_path=$CHECKPOINT_PATH > ms_log/eval_output.log 2>&1 &
--checkpoint_path=$CHECKPOINT_PATH \
--device_target=$DEVICE_TARGET > ms_log/eval_output.log 2>&1 &

View File

@ -14,12 +14,13 @@
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_standalone_train.sh DEVICE_ID DATASET_PATH"
echo "for example: sh scripts/run_standalone_train.sh 0 /dataset_path"
echo "sh scripts/run_standalone_train.sh DEVICE_ID DEVICE_TARGET DATASET_PATH"
echo "for example: sh scripts/run_standalone_train.sh 0 GPU /dataset_path"
echo "After running the script, the network runs in the background, The log will be generated in ms_log/output.log"
export DEVICE_ID=$1
DATA_URL=$2
DEVICE_TARGET=$2
DATA_URL=$3
mkdir -p ms_log
CUR_DIR=`pwd`
@ -31,4 +32,5 @@ python -u train.py \
--ckpt_path="checkpoint" \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--device_target=$DEVICE_TARGET \
--do_eval=True > ms_log/output.log 2>&1 &

View File

@ -16,11 +16,14 @@
import os
import sys
import argparse
import random
import numpy as np
from mindspore import context, ParallelMode
from mindspore.communication.management import init
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
import mindspore.dataset.engine as de
from src.deepfm import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
@ -34,24 +37,41 @@ parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path
parser.add_argument('--eval_file_name', type=str, default="./auc.log", help='eval file path')
parser.add_argument('--loss_file_name', type=str, default="./loss.log", help='loss file path')
parser.add_argument('--do_eval', type=bool, default=True, help='Do evaluation or not.')
parser.add_argument('--device_target', type=str, default="Ascend", help='Ascend, GPU, or CPU')
args_opt, _ = parser.parse_known_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
rank_size = int(os.environ.get("RANK_SIZE", 1))
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
if __name__ == '__main__':
data_config = DataConfig()
model_config = ModelConfig()
train_config = TrainConfig()
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
init()
rank_id = int(os.environ.get('RANK_ID'))
if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
init()
rank_id = int(os.environ.get('RANK_ID'))
elif args_opt.device_target == "GPU":
init("nccl")
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
rank_id = get_rank()
else:
print("Unsupported device_target ", args_opt.device_target)
exit()
else:
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
rank_size = None
rank_id = None
@ -73,6 +93,8 @@ if __name__ == '__main__':
callback_list = [time_callback, loss_callback]
if train_config.save_checkpoint:
if rank_size:
train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
keep_checkpoint_max=train_config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,