modezoo wide&deep run clusters

This commit is contained in:
yao_yf 2020-07-16 21:12:06 +08:00
parent b5d8dad47d
commit 3dbe872596
25 changed files with 405 additions and 50 deletions

View File

@ -97,7 +97,7 @@ class ReshapeInfo : public OperatorInfo {
TensorLayout output_layout_;
bool input_layout_set_flag_;
bool output_layout_set_flag_;
bool is_generating_costs_;
bool is_generating_costs_ = false;
bool is_skip_ = false;
std::string pre_operator_name_;
std::string next_operator_name_;

View File

@ -16,7 +16,7 @@ Arguments:
* `--data_path`: Dataset storage path (Default: ./criteo_data/).
## Dataset
The Criteo datasets are used for model training and evaluation.
The common used benchmark datasets are used for model training and evaluation.
## Running Code
@ -63,6 +63,7 @@ Arguments:
* `--ckpt_path`The location of the checkpoint file.
* `--eval_file_name` : Eval output file.
* `--loss_file_name` : Loss output file.
* `--dataset_type` : tfrecord/mindrecord/hd5.
To train the model in one device, command as follows:
```
@ -84,6 +85,7 @@ Arguments:
* `--ckpt_path`The location of the checkpoint file.
* `--eval_file_name` : Eval output file.
* `--loss_file_name` : Loss output file.
* `--dataset_type` : tfrecord/mindrecord/hd5.
To train the model in distributed, command as follows:
```
@ -95,6 +97,19 @@ bash run_multinpu_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
```
To train the model in clusters, command as follows:'''
```
# deploy wide&deep script in clusters
# CLUSTER_CONFIG is a json file, the sample is in script/.
# EXECUTE_PATH is the scripts path after the deploy.
bash deploy_cluster.sh CLUSTER_CONFIG_PATH EXECUTE_PATH
# enter EXECUTE_PATH, and execute start_cluster.sh as follows.
# MODE: "host_device_mix"
bash start_cluster.sh CLUSTER_CONFIG_PATH EPOCH_SIZE VOCAB_SIZE EMB_DIM
DATASET ENV_SH RANK_TABLE_FILE MODE
```
To evaluate the model, command as follows:
```
python eval.py

View File

@ -22,7 +22,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset
from src.datasets import create_dataset, DataType
from src.metrics import AUCMetric
from src.config import WideDeepConfig
@ -69,8 +69,14 @@ def test_eval(config):
"""
data_path = config.data_path
batch_size = config.batch_size
ds_eval = create_dataset(data_path, train_mode=False, epochs=2,
batch_size=batch_size)
if config.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif config.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
dataset_type = DataType.H5
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, data_type=dataset_type)
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
net_builder = ModelBuilder()

View File

@ -0,0 +1,21 @@
{
"rank_size": 32,
"cluster": {
"xx.xx.xx.xx": {
"user": "",
"passwd": ""
},
"xx.xx.xx.xx": {
"user": "",
"passwd": ""
},
"xx.xx.xx.xx": {
"user": "",
"passwd": ""
},
"xx.xx.xx.xx": {
"user": "",
"passwd": ""
}
}
}

View File

@ -0,0 +1,95 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
SSH="ssh -o StrictHostKeyChecking=no"
SCP="scp -o StrictHostKeyChecking=no"
error_msg()
{
local msg="$*"
echo "[ERROR]: $msg" 1>&2
exit 1
}
ssh_pass()
{
local node="$1"
local user="$2"
local passwd="$3"
shift 3
local cmd="$*"
sshpass -p "${passwd}" ${SSH} "${user}"@"${node}" ${cmd}
}
scp_pass()
{
local node="$1"
local user="$2"
local passwd="$3"
local src="$4"
local target="$5"
sshpass -p "${passwd}" ${SCP} -r "${src}" "${user}"@"${node}":"${target}"
}
rscp_pass()
{
local node="$1"
local user="$2"
local passwd="$3"
local src="$4"
local target="$5"
sshpass -p "${passwd}" ${SCP} -r "${user}"@"${node}":"${src}" "${target}"
}
get_rank_size()
{
local cluster_config=$1
cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["rank_size"])'
}
get_train_dataset()
{
local cluster_config=$1
cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["train_dataset"])'
}
get_cluster_list()
{
local cluster_config=$1
cat ${cluster_config} | python3 -c 'import sys,json;[print(node) for node in json.load(sys.stdin)["cluster"].keys()]' | sort
}
get_node_user()
{
local cluster_config=$1
local node=$2
cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["cluster"]['\"${node}\"']["user"])'
}
get_node_passwd()
{
local cluster_config=$1
local node=$2
cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["cluster"]['\"${node}\"']["passwd"])'
}
rsync_sshpass()
{
local node=$1
local user="$2"
local passwd="$3"
scp_pass "${node}" "${user}" "${passwd}" /usr/local/bin/sshpass /usr/local/bin/sshpass
}

View File

@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
SCRIPTPATH="$( cd "$(dirname "$0")" || exit ; pwd -P )"
# shellcheck source=/dev/null
source $SCRIPTPATH/common.sh
cluster_config_path=$1
execute_path=$2
RANK_SIZE=$(get_rank_size ${cluster_config_path})
RANK_START=0
node_list=$(get_cluster_list ${cluster_config_path})
for node in ${node_list}
do
user=$(get_node_user ${cluster_config_path} ${node})
passwd=$(get_node_passwd ${cluster_config_path} ${node})
echo "------------------${user}@${node}---------------------"
ssh_pass ${node} ${user} ${passwd} "rm -rf ${execute_path}"
scp_pass ${node} ${user} ${passwd} $SCRIPTPATH/../../wide_and_deep ${execute_path}
RANK_START=$[RANK_START+8]
if [[ $RANK_START -ge $RANK_SIZE ]]; then
break;
fi
done

View File

@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
execute_path=$(pwd)
echo ${execute_path}
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
echo ${self_path}
export RANK_SIZE=$1
RANK_START=$2
EPOCH_SIZE=$3
VOCAB_SIZE=$4
EMB_DIM=$5
DATASET=$6
ENV_SH=$7
MODE=$8
export MINDSPORE_HCCL_CONFIG=$9
export RANK_TABLE_FILE=$9
DEVICE_START=0
# shellcheck source=/dev/null
source $ENV_SH
for((i=0;i<=7;i++));
do
export RANK_ID=$[i+RANK_START]
export DEVICE_ID=$[i+DEVICE_START]
rm -rf ${execute_path}/device_$RANK_ID
mkdir ${execute_path}/device_$RANK_ID
cd ${execute_path}/device_$RANK_ID || exit
if [ $MODE == "host_device_mix" ]; then
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 &
else
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 &
fi
done

View File

@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
execute_path=$(pwd)
echo ${execute_path}
script_self=$(readlink -f "$0")
SCRIPTPATH=$(dirname "${script_self}")
echo ${SCRIPTPATH}
# shellcheck source=/dev/null
source $SCRIPTPATH/common.sh
cluster_config_path=$1
RANK_SIZE=$(get_rank_size ${cluster_config_path})
RANK_START=0
node_list=$(get_cluster_list ${cluster_config_path})
EPOCH_SIZE=$2
VOCAB_SIZE=$3
EMB_DIM=$4
DATASET=$5
MINDSPORE_HCCL_CONFIG_PATH=$6
ENV_SH=$7
MODE=$8
for node in ${node_list}
do
user=$(get_node_user ${cluster_config_path} ${node})
passwd=$(get_node_passwd ${cluster_config_path} ${node})
echo "------------------${user}@${node}---------------------"
if [ $MODE == "host_device_mix" ]; then
ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${MINDSPORE_HCCL_CONFIG_PATH}"
else
echo "[ERROR] mode is wrong"
exit 1
fi
RANK_START=$[RANK_START+8]
if [[ $RANK_START -ge $RANK_SIZE ]]; then
break;
fi
done

View File

@ -51,7 +51,7 @@ class LossCallBack(Callback):
wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True)
# raise ValueError
if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
@ -76,7 +76,7 @@ class EvalCallBack(Callback):
Args:
print_per_step (int): Print loss every times. Default: 1.
"""
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1):
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1, host_device_mix=False):
super(EvalCallBack, self).__init__()
if not isinstance(print_per_step, int) or print_per_step < 0:
raise ValueError("print_per_step must be int and >= 0.")
@ -87,6 +87,7 @@ class EvalCallBack(Callback):
self.aucMetric.clear()
self.eval_file_name = config.eval_file_name
self.eval_values = []
self.host_device_mix = host_device_mix
def epoch_end(self, run_context):
"""
@ -98,7 +99,7 @@ class EvalCallBack(Callback):
context.set_auto_parallel_context(strategy_ckpt_save_file="",
strategy_ckpt_load_file="./strategy_train.ckpt")
start_time = time.time()
out = self.model.eval(self.eval_dataset)
out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix))
end_time = time.time()
eval_time = int(end_time - start_time)

View File

@ -38,6 +38,8 @@ def argparse_init():
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
parser.add_argument("--eval_file_name", type=str, default="eval.log")
parser.add_argument("--loss_file_name", type=str, default="loss.log")
parser.add_argument("--host_device_mix", type=int, default=0)
parser.add_argument("--dataset_type", type=str, default="tfrecord")
return parser
@ -68,6 +70,8 @@ class WideDeepConfig():
self.eval_file_name = "eval.log"
self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
self.host_device_mix = 0
self.dataset_type = "tfrecord"
def argparse_init(self):
"""
@ -97,3 +101,5 @@ class WideDeepConfig():
self.eval_file_name = args.eval_file_name
self.loss_file_name = args.loss_file_name
self.ckpt_path = args.ckpt_path
self.host_device_mix = args.host_device_mix
self.dataset_type = args.dataset_type

View File

@ -20,7 +20,7 @@ from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn import Dropout
from mindspore.nn.optim import Adam, FTRL
from mindspore.nn.optim import Adam, FTRL, LazyAdam
# from mindspore.nn.metrics import Metric
from mindspore.common.initializer import Uniform, initializer
# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
@ -82,7 +82,7 @@ class DenseLayer(nn.Cell):
"""
def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
keep_prob=0.7, use_activation=True, convert_dtype=True, drop_out=False):
keep_prob=0.5, use_activation=True, convert_dtype=True, drop_out=False):
super(DenseLayer, self).__init__()
weight_init, bias_init = weight_bias_init
self.weight = init_method(
@ -137,8 +137,10 @@ class WideDeepModel(nn.Cell):
def __init__(self, config):
super(WideDeepModel, self).__init__()
self.batch_size = config.batch_size
host_device_mix = bool(config.host_device_mix)
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if is_auto_parallel:
self.batch_size = self.batch_size * get_group_size()
self.field_size = config.field_size
self.vocab_size = config.vocab_size
@ -187,16 +189,29 @@ class WideDeepModel(nn.Cell):
self.weight_bias_init,
self.deep_layer_act,
use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
self.embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
self.mul = P.Mul()
self.wide_mul = P.Mul()
self.deep_mul = P.Mul()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.reshape = P.Reshape()
self.deep_reshape = P.Reshape()
self.square = P.Square()
self.shape = P.Shape()
self.tile = P.Tile()
self.concat = P.Concat(axis=1)
self.cast = P.Cast()
if is_auto_parallel and host_device_mix:
self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
self.deep_embeddinglookup = nn.EmbeddingLookup()
self.deep_embeddinglookup.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
self.wide_embeddinglookup = nn.EmbeddingLookup()
self.wide_embeddinglookup.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1)))
self.deep_reshape.add_prim_attr("skip_redistribution", True)
self.reduce_sum.add_prim_attr("cross_batch", True)
else:
self.deep_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
self.wide_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
def construct(self, id_hldr, wt_hldr):
"""
@ -206,13 +221,13 @@ class WideDeepModel(nn.Cell):
"""
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
# Wide layer
wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr)
wx = self.mul(wide_id_weight, mask)
wide_id_weight = self.wide_embeddinglookup(self.wide_w, id_hldr)
wx = self.wide_mul(wide_id_weight, mask)
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
# Deep layer
deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr)
vx = self.mul(deep_id_embs, mask)
deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
deep_id_embs = self.deep_embeddinglookup(self.embedding_table, id_hldr)
vx = self.deep_mul(deep_id_embs, mask)
deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
deep_in = self.dense_layer_1(deep_in)
deep_in = self.dense_layer_2(deep_in)
deep_in = self.dense_layer_3(deep_in)
@ -233,19 +248,28 @@ class NetWithLossClass(nn.Cell):
def __init__(self, network, config):
super(NetWithLossClass, self).__init__(auto_prefix=False)
host_device_mix = bool(config.host_device_mix)
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.no_l2loss = host_device_mix and is_auto_parallel
self.network = network
self.l2_coef = config.l2_coef
self.loss = P.SigmoidCrossEntropyWithLogits()
self.square = P.Square()
self.reduceMean_false = P.ReduceMean(keep_dims=False)
if is_auto_parallel:
self.reduceMean_false.add_prim_attr("cross_batch", True)
self.reduceSum_false = P.ReduceSum(keep_dims=False)
def construct(self, batch_ids, batch_wts, label):
predict, embedding_table = self.network(batch_ids, batch_wts)
log_loss = self.loss(predict, label)
wide_loss = self.reduceMean_false(log_loss)
l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v
if self.no_l2loss:
deep_loss = wide_loss
else:
l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v
return wide_loss, deep_loss
@ -267,12 +291,15 @@ class TrainStepWrap(nn.Cell):
Append Adam and FTRL optimizers to the training network after that construct
function can be called to create the backward graph.
Args:
network (Cell): the training network. Note that loss function should have been added.
sens (Number): The adjust parameter. Default: 1000.0
network (Cell): The training network. Note that loss function should have been added.
sens (Number): The adjust parameter. Default: 1024.0
host_device_mix (Bool): Whether run in host and device mix mode. Default: False
"""
def __init__(self, network, sens=1024.0):
def __init__(self, network, sens=1024.0, host_device_mix=False):
super(TrainStepWrap, self).__init__()
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.network = network
self.network.set_train()
self.trainable_params = network.trainable_params()
@ -285,10 +312,19 @@ class TrainStepWrap(nn.Cell):
weights_d.append(params)
self.weights_w = ParameterTuple(weights_w)
self.weights_d = ParameterTuple(weights_d)
self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w,
l1=1e-8, l2=1e-8, initial_accum=1.0)
self.optimizer_d = Adam(
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
if host_device_mix and is_auto_parallel:
self.optimizer_d = LazyAdam(
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
self.optimizer_w.sparse_opt.add_prim_attr("primitive_target", "CPU")
self.optimizer_d.sparse_opt.add_prim_attr("primitive_target", "CPU")
else:
self.optimizer_d = Adam(
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
self.hyper_map = C.HyperMap()
self.grad_w = C.GradOperation('grad_w', get_by_list=True,
sens_param=True)

View File

@ -17,7 +17,7 @@ from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack
from src.datasets import create_dataset
from src.datasets import create_dataset, DataType
from src.config import WideDeepConfig
@ -63,7 +63,14 @@ def test_train(configure):
data_path = configure.data_path
batch_size = configure.batch_size
epochs = configure.epochs
ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
if configure.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif configure.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
dataset_type = DataType.H5
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, data_type=dataset_type)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
net_builder = ModelBuilder()

View File

@ -19,7 +19,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset
from src.datasets import create_dataset, DataType
from src.metrics import AUCMetric
from src.config import WideDeepConfig
@ -67,8 +67,16 @@ def test_train_eval(config):
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size)
if config.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif config.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
dataset_type = DataType.H5
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, data_type=dataset_type)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, data_type=dataset_type)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

View File

@ -27,13 +27,14 @@ from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset
from src.datasets import create_dataset, DataType
from src.metrics import AUCMetric
from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
context.set_context(variable_memory_max_size="24GB")
context.set_context(enable_sparse=True)
cost_model_context.set_cost_model_context(multi_subgraphs=True)
init()
@ -46,7 +47,7 @@ def get_WideDeep_net(config):
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
loss_net = VirtualDatasetCellTriple(loss_net)
train_net = TrainStepWrap(loss_net)
train_net = TrainStepWrap(loss_net, host_device_mix=bool(config.host_device_mix))
eval_net = PredictWithSigmoid(WideDeep_net)
eval_net = VirtualDatasetCellTriple(eval_net)
return train_net, eval_net
@ -81,19 +82,28 @@ def train_and_eval(config):
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
if config.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif config.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
dataset_type = DataType.H5
host_device_mix = bool(config.host_device_mix)
print("epochs is {}".format(epochs))
if config.full_batch:
context.set_auto_parallel_context(full_batch=True)
de.config.set_seed(1)
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size*get_group_size())
batch_size=batch_size*get_group_size(), data_type=dataset_type)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size*get_group_size())
batch_size=batch_size*get_group_size(), data_type=dataset_type)
else:
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
batch_size=batch_size, rank_id=get_rank(),
rank_size=get_group_size(), data_type=dataset_type)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
batch_size=batch_size, rank_id=get_rank(),
rank_size=get_group_size(), data_type=dataset_type)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
@ -105,18 +115,24 @@ def train_and_eval(config):
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
callback = LossCallBack(config=config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
if not host_device_mix:
callback_list.append(ckpoint_cb)
model.train(epochs, ds_train, callbacks=callback_list, dataset_sink_mode=(not host_device_mix))
if __name__ == "__main__":
wide_deep_config = WideDeepConfig()
wide_deep_config.argparse_init()
if wide_deep_config.host_device_mix == 1:
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
else:
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
train_and_eval(wide_deep_config)

View File

@ -25,7 +25,7 @@ from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset
from src.datasets import create_dataset, DataType
from src.metrics import AUCMetric
from src.config import WideDeepConfig
@ -73,11 +73,19 @@ def train_and_eval(config):
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
if config.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif config.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
dataset_type = DataType.H5
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
batch_size=batch_size, rank_id=get_rank(),
rank_size=get_group_size(), data_type=dataset_type)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
batch_size=batch_size, rank_id=get_rank(),
rank_size=get_group_size(), data_type=dataset_type)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))

View File

@ -21,10 +21,10 @@ export RANK_SIZE=$DEVICE_NUM
unset SLOG_PRINT_TO_STDOUT
export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json
CODE_DIR="./"
if [ -d ${BASE_PATH}/../../../../model_zoo/wide_and_deep ]; then
CODE_DIR=${BASE_PATH}/../../../../model_zoo/wide_and_deep
elif [ -d ${BASE_PATH}/../../model_zoo/wide_and_deep ]; then
CODE_DIR=${BASE_PATH}/../../model_zoo/wide_and_deep
if [ -d ${BASE_PATH}/../../../../model_zoo/official/recommend/wide_and_deep ]; then
CODE_DIR=${BASE_PATH}/../../../../model_zoo/official/recommend/wide_and_deep
elif [ -d ${BASE_PATH}/../../model_zoo/official/recommend/wide_and_deep ]; then
CODE_DIR=${BASE_PATH}/../../model_zoo/official/recommend/wide_and_deep
else
echo "[ERROR] code dir is not found"
fi