forked from mindspore-Ecosystem/mindspore

fix_modelzoo_widedeep_run_multinup_train

This commit is contained in:
parent 717169f99e
commit 41715677e7
README.md
@@ -13,26 +13,28 @@ The Criteo datasets are used for model training and evaluation.
 The entire code structure is as follows:
 ```
 |--- wide_and_deep/
-        train_and_test.py                "Entrance of Wide&Deep model training and evaluation"
-        test.py                          "Entrance of Wide&Deep model evaluation"
-        train.py                         "Entrance of Wide&Deep model training"
-        train_and_test_multinpu.py       "Entrance of Wide&Deep model data parallel training and evaluation"
-        |--- src/     "entrance of training and evaluation"
-                config.py                "parameters configuration"
-                dataset.py               "Dataset loader class"
-                process_data.py          "process dataset"
-                preprocess_data.py       "pre_process dataset"
-                WideDeep.py              "Model structure"
-                callbacks.py             "Callback class for training and evaluation"
-                metrics.py               "Metric class"
-        |--- script/  "run shell dir"
-                run_multinpu_train.sh        "run data parallel"
+        train_and_eval.py                "Entrance of Wide&Deep model training and evaluation"
+        eval.py                          "Entrance of Wide&Deep model evaluation"
+        train.py                         "Entrance of Wide&Deep model training"
+        train_and_eval_multinpu.py       "Entrance of Wide&Deep model data parallel training and evaluation"
+        train_and_eval_auto_parallel.py
+        |--- src/     "Entrance of training and evaluation"
+                config.py                "Parameters configuration"
+                dataset.py               "Dataset loader class"
+                process_data.py          "Process dataset"
+                preprocess_data.py       "Pre_process dataset"
+                wide_and_deep.py         "Model structure"
+                callbacks.py             "Callback class for training and evaluation"
+                metrics.py               "Metric class"
+        |--- script/  "Run shell dir"
+                run_multinpu_train.sh        "Run data parallel"
+                run_auto_parallel_train.sh   "Run auto parallel"
 ```

 ### Train and evaluate model
 To train and evaluate the model, run the following command:
 ```
-python train_and_test.py
+python train_and_eval.py
 ```
 Arguments:
   * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument.
@@ -44,6 +46,7 @@ Arguments:
   * `--emb_dim`: The dense embedding dimension of sparse feature.
   * `--deep_layers_dim`: The dimension of all deep layers.
   * `--deep_layers_act`: The activation of all deep layers.
+  * `--dropout_flag`: Whether to do dropout.
   * `--keep_prob`: The rate to keep in dropout layer.
   * `--ckpt_path`: The location of the checkpoint file.
   * `--eval_file_name`: Eval output file.
@@ -63,6 +66,7 @@ Arguments:
   * `--emb_dim`: The dense embedding dimension of sparse feature.
   * `--deep_layers_dim`: The dimension of all deep layers.
   * `--deep_layers_act`: The activation of all deep layers.
+  * `--dropout_flag`: Whether to do dropout.
   * `--keep_prob`: The rate to keep in dropout layer.
   * `--ckpt_path`: The location of the checkpoint file.
   * `--eval_file_name`: Eval output file.
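Both flags above act on the same layer: `--dropout_flag` switches dropout on, and `--keep_prob` sets how much of the activation survives it. A quick numeric sketch of the usual keep-probability semantics (illustration only; MindSpore's `Dropout` is assumed here to use inverted scaling):

```python
keep_prob = 0.7
drop_rate = 1.0 - keep_prob   # fraction of activations zeroed during training
rescale = 1.0 / keep_prob     # inverted-dropout scaling applied to survivors
print(round(drop_rate, 2), round(rescale, 2))  # 0.3 1.43
```

With the defaults (`--keep_prob 1.0`, `--dropout_flag 0`) dropout is a no-op twice over.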
@@ -70,13 +74,17 @@ Arguments:

 To train the model in distributed mode, run the following command:
 ```
-# configure environment path, RANK_TABLE_FILE, RANK_SIZE, MINDSPORE_HCCL_CONFIG_PATH before training
-bash run_multinpu_train.sh
+# configure environment path before training
+bash run_multinpu_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
+```
+```
+# configure environment path before training
+bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
 ```

 To evaluate the model, run the following command:
 ```
-python test.py
+python eval.py
 ```
 Arguments:
   * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument.
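Both launch scripts now take the same four positional arguments, and each exports the rank table path under two environment names. A hedged sketch of that argument-to-environment mapping (hypothetical helper, not repository code):

```python
import sys

# Mirrors the exports at the top of run_multinpu_train.sh / run_auto_parallel_train.sh.
rank_size, epochs, dataset, rank_table = sys.argv[1:5]
env = {
    "RANK_SIZE": rank_size,
    "EPOCH_SIZE": epochs,
    "DATASET": dataset,
    "RANK_TABLE_FILE": rank_table,
    "MINDSPORE_HCCL_CONFIG_PATH": rank_table,  # same file exported under a second name
}
```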
script/run_auto_parallel_train.sh (new file)
@@ -0,0 +1,35 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
+execute_path=$(pwd)
+script_self=$(readlink -f "$0")
+self_path=$(dirname "${script_self}")
+export RANK_SIZE=$1
+export EPOCH_SIZE=$2
+export DATASET=$3
+export RANK_TABLE_FILE=$4
+export MINDSPORE_HCCL_CONFIG_PATH=$4
+
+for((i=0;i<$RANK_SIZE;i++));
+do
+    rm -rf ${execute_path}/device_$i/
+    mkdir ${execute_path}/device_$i/
+    cd ${execute_path}/device_$i/ || exit
+    export RANK_ID=$i
+    export DEVICE_ID=$i
+    python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
+done
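The launcher's control flow, restated as a hedged Python sketch (a hypothetical helper, not part of the commit): make a scratch directory per device, export the rank identity through the environment, and start one background trainer per rank.

```python
import os
import subprocess

def launch(rank_size, epochs, dataset, rank_table, script="train_and_eval_auto_parallel.py"):
    """Hypothetical equivalent of the shell loop above, for illustration."""
    base = dict(os.environ, RANK_SIZE=str(rank_size), EPOCH_SIZE=str(epochs),
                DATASET=dataset, RANK_TABLE_FILE=rank_table,
                MINDSPORE_HCCL_CONFIG_PATH=rank_table)
    procs = []
    for i in range(rank_size):                       # ranks 0..rank_size-1, one per device
        workdir = f"device_{i}"
        os.makedirs(workdir, exist_ok=True)          # the shell version rm -rf's and mkdir's
        env = dict(base, RANK_ID=str(i), DEVICE_ID=str(i))
        with open(os.path.join(workdir, f"train_deep{i}.log"), "w") as log:
            procs.append(subprocess.Popen(
                ["python", "-s", os.path.join("..", script),
                 f"--data_path={dataset}", f"--epochs={epochs}"],
                cwd=workdir, env=env, stdout=log, stderr=subprocess.STDOUT))
    return procs
```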
script/run_multinpu_train.sh
@@ -24,12 +24,12 @@ export DATASET=$3
 export RANK_TABLE_FILE=$4
 export MINDSPORE_HCCL_CONFIG_PATH=$4

-for((i=0;i<=$RANK_SIZE;i++));
+for((i=0;i<$RANK_SIZE;i++));
 do
     rm -rf ${execute_path}/device_$i/
     mkdir ${execute_path}/device_$i/
     cd ${execute_path}/device_$i/ || exit
     export RANK_ID=$i
     export DEVICE_ID=$i
-    python -s ${self_path}/../train_and_test_multinpu.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
+    python -s ${self_path}/../train_and_eval_multinpu.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
 done
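The loop-bound change is the substantive fix in this script: `i<=$RANK_SIZE` started one process too many. The index arithmetic, spelled out (illustration only):

```python
# With RANK_SIZE=8, the strict bound yields ranks 0..7, one per device;
# the old `<=` bound would also spawn an invalid ninth rank, 8.
rank_size = 8
fixed = list(range(rank_size))        # for((i=0;i<RANK_SIZE;i++))
buggy = list(range(rank_size + 1))    # for((i=0;i<=RANK_SIZE;i++))
print(len(fixed), len(buggy))         # 8 9
```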
src/config.py
@@ -31,7 +31,7 @@ def argparse_init():
     parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
     parser.add_argument("--deep_layer_act", type=str, default='relu')
     parser.add_argument("--keep_prob", type=float, default=1.0)
-
+    parser.add_argument("--dropout_flag", type=int, default=0)
     parser.add_argument("--output_path", type=str, default="./output/")
     parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
     parser.add_argument("--eval_file_name", type=str, default="eval.log")
@@ -86,7 +86,7 @@ class WideDeepConfig():
         self.weight_bias_init = ['normal', 'normal']
         self.emb_init = 'normal'
         self.init_args = [-0.01, 0.01]
-        self.dropout_flag = False
+        self.dropout_flag = bool(args.dropout_flag)
         self.l2_coef = 8e-5

         self.output_path = args.output_path
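Taken together, the two config.py hunks plumb one switch end to end: the flag arrives on the command line as a 0/1 integer and is coerced to a bool on the config object. A minimal standalone sketch of that pattern (plain argparse, not the repository module):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dropout_flag", type=int, default=0)  # as in argparse_init()

args = parser.parse_args(["--dropout_flag", "1"])
dropout_flag = bool(args.dropout_flag)  # as in WideDeepConfig
print(dropout_flag)  # True; the default of 0 keeps it False
```

An int flag is used rather than `type=bool` because `bool("0")` is `True` in Python, so a boolean argparse type would silently treat any non-empty string as on.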
src/wide_and_deep.py
@@ -19,7 +19,7 @@ import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-# from mindspore.nn import Dropout
+from mindspore.nn import Dropout
 from mindspore.nn.optim import Adam, FTRL
 # from mindspore.nn.metrics import Metric
 from mindspore.common.initializer import Uniform, initializer
@@ -82,7 +82,7 @@ class DenseLayer(nn.Cell):
     """

     def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
-                 keep_prob=0.7, scale_coef=1.0, convert_dtype=True):
+                 keep_prob=0.7, scale_coef=1.0, convert_dtype=True, drop_out=False):
         super(DenseLayer, self).__init__()
         weight_init, bias_init = weight_bias_init
         self.weight = init_method(
@@ -92,11 +92,12 @@ class DenseLayer(nn.Cell):
         self.matmul = P.MatMul(transpose_b=False)
         self.bias_add = P.BiasAdd()
         self.cast = P.Cast()
-        #self.dropout = Dropout(keep_prob=keep_prob)
+        self.dropout = Dropout(keep_prob=keep_prob)
         self.mul = P.Mul()
         self.realDiv = P.RealDiv()
         self.scale_coef = scale_coef
         self.convert_dtype = convert_dtype
+        self.drop_out = drop_out

     def _init_activation(self, act_str):
         act_str = act_str.lower()
@@ -110,8 +111,8 @@ class DenseLayer(nn.Cell):

     def construct(self, x):
         x = self.act_func(x)
-        # if self.training:
-        #     x = self.dropout(x)
+        if self.training and self.drop_out:
+            x = self.dropout(x)
         x = self.mul(x, self.scale_coef)
         if self.convert_dtype:
             x = self.cast(x, mstype.float16)
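The gate above is the heart of the change: dropout runs only when the network is in training mode and the layer was constructed with `drop_out=True`. A stripped-down sketch of just that mechanism (assuming MindSpore's `nn.Cell`, its `training` flag, and `nn.Dropout(keep_prob=...)` behave as this diff uses them):

```python
import mindspore.nn as nn

class GatedDropout(nn.Cell):
    """Illustrative only: the training-and-flag gate from DenseLayer.construct."""
    def __init__(self, keep_prob=0.7, drop_out=False):
        super(GatedDropout, self).__init__()
        self.dropout = nn.Dropout(keep_prob=keep_prob)  # the call the diff un-comments
        self.drop_out = drop_out                        # wired from config.dropout_flag

    def construct(self, x):
        # Fires only during training and only if the layer opted in;
        # evaluation always sees undropped activations.
        if self.training and self.drop_out:
            x = self.dropout(x)
        return x
```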
@@ -163,23 +164,28 @@ class WideDeepModel(nn.Cell):
         self.dense_layer_1 = DenseLayer(self.all_dim_list[0],
                                         self.all_dim_list[1],
                                         self.weight_bias_init,
-                                        self.deep_layer_act, convert_dtype=True)
+                                        self.deep_layer_act,
+                                        convert_dtype=True, drop_out=config.dropout_flag)
         self.dense_layer_2 = DenseLayer(self.all_dim_list[1],
                                         self.all_dim_list[2],
                                         self.weight_bias_init,
-                                        self.deep_layer_act, convert_dtype=True)
+                                        self.deep_layer_act,
+                                        convert_dtype=True, drop_out=config.dropout_flag)
         self.dense_layer_3 = DenseLayer(self.all_dim_list[2],
                                         self.all_dim_list[3],
                                         self.weight_bias_init,
-                                        self.deep_layer_act, convert_dtype=True)
+                                        self.deep_layer_act,
+                                        convert_dtype=True, drop_out=config.dropout_flag)
         self.dense_layer_4 = DenseLayer(self.all_dim_list[3],
                                         self.all_dim_list[4],
                                         self.weight_bias_init,
-                                        self.deep_layer_act, convert_dtype=True)
+                                        self.deep_layer_act,
+                                        convert_dtype=True, drop_out=config.dropout_flag)
         self.dense_layer_5 = DenseLayer(self.all_dim_list[4],
                                         self.all_dim_list[5],
                                         self.weight_bias_init,
-                                        self.deep_layer_act, convert_dtype=True)
+                                        self.deep_layer_act,
+                                        convert_dtype=True, drop_out=config.dropout_flag)

         self.gather_v2 = P.GatherV2()
         self.mul = P.Mul()
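All five constructions change identically, which hints at a tighter form. A sketch (not part of the commit) of a hypothetical helper that keeps the `drop_out` wiring in one place:

```python
def build_dense_stack(all_dim_list, weight_bias_init, deep_layer_act, dropout_flag):
    """Hypothetical refactor: the five DenseLayers above, built in a loop."""
    return [DenseLayer(all_dim_list[k], all_dim_list[k + 1],
                       weight_bias_init, deep_layer_act,
                       convert_dtype=True, drop_out=dropout_flag)
            for k in range(5)]
```

`WideDeepModel.__init__` would then unpack the returned list into the named `dense_layer_1` .. `dense_layer_5` attributes, since the surrounding code addresses them individually.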
train_and_eval_auto_parallel.py
@@ -71,11 +71,10 @@ class ModelBuilder():
         return get_WideDeep_net(config)


-def test_train_eval():
+def train_and_eval(config):
     """
     test_train_eval
     """
-    config = WideDeepConfig()
     data_path = config.data_path
     batch_size = config.batch_size
     epochs = config.epochs
@@ -109,9 +108,12 @@ def test_train_eval():
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                  directory=config.ckpt_path, config=ckptconfig)
+    context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
     model.train(epochs, ds_train,
                 callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])


 if __name__ == "__main__":
-    test_train_eval()
+    wide_deep_config = WideDeepConfig()
+    wide_deep_config.argparse_init()
+    train_and_eval(wide_deep_config)
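Beyond the rename, this hunk inverts ownership of the config: `__main__` builds it, lets argparse override the defaults, and passes it in, where `test_train_eval()` used to construct its own. A small sketch of what that buys (hypothetical caller, not in the commit):

```python
# With the config injected, a caller can drive train_and_eval with
# non-default settings without touching argv.
cfg = WideDeepConfig()
cfg.epochs = 1                   # attribute read by train_and_eval above
cfg.data_path = "/tmp/criteo"    # placeholder path, for illustration only
train_and_eval(cfg)
```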
train_and_eval_multinpu.py
@@ -66,7 +66,7 @@ class ModelBuilder():
         return get_WideDeep_net(config)


-def test_train_eval(config):
+def train_and_eval(config):
     """
     test_train_eval
     """
@@ -105,4 +105,4 @@ def test_train_eval(config):
 if __name__ == "__main__":
     wide_deep_config = WideDeepConfig()
     wide_deep_config.argparse_init()
-    test_train_eval(wide_deep_config)
+    train_and_eval(wide_deep_config)