fix_modelzoo_widedeep_run_multinup_train

This commit is contained in:
yao_yf 2020-06-11 10:43:20 +08:00
parent 717169f99e
commit 41715677e7
9 changed files with 88 additions and 37 deletions

View File

@ -13,26 +13,28 @@ The Criteo datasets are used for model training and evaluation.
The entire code structure is as following:
```
|--- wide_and_deep/
train_and_test.py "Entrance of Wide&Deep model training and evaluation"
test.py "Entrance of Wide&Deep model evaluation"
train.py "Entrance of Wide&Deep model training"
train_and_test_multinpu.py "Entrance of Wide&Deep model data parallel training and evaluation"
|--- src/ "entrance of training and evaluation"
config.py "parameters configuration"
dataset.py "Dataset loader class"
process_data.py "process dataset"
preprocess_data.py "pre_process dataset"
WideDeep.py "Model structure"
callbacks.py "Callback class for training and evaluation"
metrics.py "Metric class"
|--- script/ "run shell dir"
run_multinpu_train.sh "run data parallel"
train_and_eval.py "Entrance of Wide&Deep model training and evaluation"
eval.py "Entrance of Wide&Deep model evaluation"
train.py "Entrance of Wide&Deep model training"
train_and_eval_multinpu.py "Entrance of Wide&Deep model data parallel training and evaluation"
train_and_eval_auto_parallel.py
|--- src/ "Entrance of training and evaluation"
config.py "Parameters configuration"
dataset.py "Dataset loader class"
process_data.py "Process dataset"
preprocess_data.py "Pre_process dataset"
wide_and_deep.py "Model structure"
callbacks.py "Callback class for training and evaluation"
metrics.py "Metric class"
|--- script/ "Run shell dir"
run_multinpu_train.sh "Run data parallel"
run_auto_parallel_train.sh "Run auto parallel"
```
### Train and evaluate model
To train and evaluate the model, command as follows:
```
python train_and_test.py
python train_and_eval.py
```
Arguments:
* `--data_path`: This should be set to the same directory given to the data_download's data_dir argument.
@ -44,6 +46,7 @@ Arguments:
* `--emb_dim` The dense embedding dimension of sparse feature.
* `--deep_layers_dim` The dimension of all deep layers.
* `--deep_layers_act` The activation of all deep layers.
* `--dropout_flag` Whether do dropout.
* `--keep_prob` The rate to keep in dropout layer.
* `--ckpt_path`The location of the checkpoint file.
* `--eval_file_name` : Eval output file.
@ -63,6 +66,7 @@ Arguments:
* `--emb_dim` The dense embedding dimension of sparse feature.
* `--deep_layers_dim` The dimension of all deep layers.
* `--deep_layers_act` The activation of all deep layers.
* `--dropout_flag` Whether do dropout.
* `--keep_prob` The rate to keep in dropout layer.
* `--ckpt_path`The location of the checkpoint file.
* `--eval_file_name` : Eval output file.
@ -70,13 +74,17 @@ Arguments:
To train the model in distributed, command as follows:
```
# configure environment path, RANK_TABLE_FILE, RANK_SIZE, MINDSPORE_HCCL_CONFIG_PATH before training
bash run_multinpu_train.sh
# configure environment path before training
bash run_multinpu_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
```
```
# configure environment path before training
bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
```
To evaluate the model, command as follows:
```
python test.py
python eval.py
```
Arguments:
* `--data_path`: This should be set to the same directory given to the data_download's data_dir argument.

View File

@ -0,0 +1,35 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# bash run_multinpu_train.sh
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
export RANK_SIZE=$1
export EPOCH_SIZE=$2
export DATASET=$3
export RANK_TABLE_FILE=$4
export MINDSPORE_HCCL_CONFIG_PATH=$4
for((i=0;i<$RANK_SIZE;i++));
do
rm -rf ${execute_path}/device_$i/
mkdir ${execute_path}/device_$i/
cd ${execute_path}/device_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
done

View File

@ -24,12 +24,12 @@ export DATASET=$3
export RANK_TABLE_FILE=$4
export MINDSPORE_HCCL_CONFIG_PATH=$4
for((i=0;i<=$RANK_SIZE;i++));
for((i=0;i<$RANK_SIZE;i++));
do
rm -rf ${execute_path}/device_$i/
mkdir ${execute_path}/device_$i/
cd ${execute_path}/device_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../train_and_test_multinpu.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
python -s ${self_path}/../train_and_eval_multinpu.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
done

View File

@ -31,7 +31,7 @@ def argparse_init():
parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
parser.add_argument("--deep_layer_act", type=str, default='relu')
parser.add_argument("--keep_prob", type=float, default=1.0)
parser.add_argument("--dropout_flag", type=int, default=0)
parser.add_argument("--output_path", type=str, default="./output/")
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
parser.add_argument("--eval_file_name", type=str, default="eval.log")
@ -86,7 +86,7 @@ class WideDeepConfig():
self.weight_bias_init = ['normal', 'normal']
self.emb_init = 'normal'
self.init_args = [-0.01, 0.01]
self.dropout_flag = False
self.dropout_flag = bool(args.dropout_flag)
self.l2_coef = 8e-5
self.output_path = args.output_path

View File

@ -19,7 +19,7 @@ import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
# from mindspore.nn import Dropout
from mindspore.nn import Dropout
from mindspore.nn.optim import Adam, FTRL
# from mindspore.nn.metrics import Metric
from mindspore.common.initializer import Uniform, initializer
@ -82,7 +82,7 @@ class DenseLayer(nn.Cell):
"""
def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
keep_prob=0.7, scale_coef=1.0, convert_dtype=True):
keep_prob=0.7, scale_coef=1.0, convert_dtype=True, drop_out=False):
super(DenseLayer, self).__init__()
weight_init, bias_init = weight_bias_init
self.weight = init_method(
@ -92,11 +92,12 @@ class DenseLayer(nn.Cell):
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
#self.dropout = Dropout(keep_prob=keep_prob)
self.dropout = Dropout(keep_prob=keep_prob)
self.mul = P.Mul()
self.realDiv = P.RealDiv()
self.scale_coef = scale_coef
self.convert_dtype = convert_dtype
self.drop_out = drop_out
def _init_activation(self, act_str):
act_str = act_str.lower()
@ -110,8 +111,8 @@ class DenseLayer(nn.Cell):
def construct(self, x):
x = self.act_func(x)
# if self.training:
# x = self.dropout(x)
if self.training and self.drop_out:
x = self.dropout(x)
x = self.mul(x, self.scale_coef)
if self.convert_dtype:
x = self.cast(x, mstype.float16)
@ -163,23 +164,28 @@ class WideDeepModel(nn.Cell):
self.dense_layer_1 = DenseLayer(self.all_dim_list[0],
self.all_dim_list[1],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.deep_layer_act,
convert_dtype=True, drop_out=config.dropout_flag)
self.dense_layer_2 = DenseLayer(self.all_dim_list[1],
self.all_dim_list[2],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.deep_layer_act,
convert_dtype=True, drop_out=config.dropout_flag)
self.dense_layer_3 = DenseLayer(self.all_dim_list[2],
self.all_dim_list[3],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.deep_layer_act,
convert_dtype=True, drop_out=config.dropout_flag)
self.dense_layer_4 = DenseLayer(self.all_dim_list[3],
self.all_dim_list[4],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.deep_layer_act,
convert_dtype=True, drop_out=config.dropout_flag)
self.dense_layer_5 = DenseLayer(self.all_dim_list[4],
self.all_dim_list[5],
self.weight_bias_init,
self.deep_layer_act, convert_dtype=True)
self.deep_layer_act,
convert_dtype=True, drop_out=config.dropout_flag)
self.gather_v2 = P.GatherV2()
self.mul = P.Mul()

View File

@ -71,11 +71,10 @@ class ModelBuilder():
return get_WideDeep_net(config)
def test_train_eval():
def train_and_eval(config):
"""
test_train_eval
"""
config = WideDeepConfig()
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
@ -109,9 +108,12 @@ def test_train_eval():
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
if __name__ == "__main__":
test_train_eval()
wide_deep_config = WideDeepConfig()
wide_deep_config.argparse_init()
train_and_eval(wide_deep_config)

View File

@ -66,7 +66,7 @@ class ModelBuilder():
return get_WideDeep_net(config)
def test_train_eval(config):
def train_and_eval(config):
"""
test_train_eval
"""
@ -105,4 +105,4 @@ def test_train_eval(config):
if __name__ == "__main__":
wide_deep_config = WideDeepConfig()
wide_deep_config.argparse_init()
test_train_eval(wide_deep_config)
train_and_eval(wide_deep_config)