forked from OSSInnovation/mindspore
modezoo wide&deep run clusters
This commit is contained in:
parent
b5d8dad47d
commit
3dbe872596
|
@ -97,7 +97,7 @@ class ReshapeInfo : public OperatorInfo {
|
||||||
TensorLayout output_layout_;
|
TensorLayout output_layout_;
|
||||||
bool input_layout_set_flag_;
|
bool input_layout_set_flag_;
|
||||||
bool output_layout_set_flag_;
|
bool output_layout_set_flag_;
|
||||||
bool is_generating_costs_;
|
bool is_generating_costs_ = false;
|
||||||
bool is_skip_ = false;
|
bool is_skip_ = false;
|
||||||
std::string pre_operator_name_;
|
std::string pre_operator_name_;
|
||||||
std::string next_operator_name_;
|
std::string next_operator_name_;
|
||||||
|
|
|
@ -16,7 +16,7 @@ Arguments:
|
||||||
* `--data_path`: Dataset storage path (Default: ./criteo_data/).
|
* `--data_path`: Dataset storage path (Default: ./criteo_data/).
|
||||||
|
|
||||||
## Dataset
|
## Dataset
|
||||||
The Criteo datasets are used for model training and evaluation.
|
The common used benchmark datasets are used for model training and evaluation.
|
||||||
|
|
||||||
## Running Code
|
## Running Code
|
||||||
|
|
||||||
|
@ -63,6 +63,7 @@ Arguments:
|
||||||
* `--ckpt_path`:The location of the checkpoint file.
|
* `--ckpt_path`:The location of the checkpoint file.
|
||||||
* `--eval_file_name` : Eval output file.
|
* `--eval_file_name` : Eval output file.
|
||||||
* `--loss_file_name` : Loss output file.
|
* `--loss_file_name` : Loss output file.
|
||||||
|
* `--dataset_type` : tfrecord/mindrecord/hd5.
|
||||||
|
|
||||||
To train the model in one device, command as follows:
|
To train the model in one device, command as follows:
|
||||||
```
|
```
|
||||||
|
@ -84,6 +85,7 @@ Arguments:
|
||||||
* `--ckpt_path`:The location of the checkpoint file.
|
* `--ckpt_path`:The location of the checkpoint file.
|
||||||
* `--eval_file_name` : Eval output file.
|
* `--eval_file_name` : Eval output file.
|
||||||
* `--loss_file_name` : Loss output file.
|
* `--loss_file_name` : Loss output file.
|
||||||
|
* `--dataset_type` : tfrecord/mindrecord/hd5.
|
||||||
|
|
||||||
To train the model in distributed, command as follows:
|
To train the model in distributed, command as follows:
|
||||||
```
|
```
|
||||||
|
@ -95,6 +97,19 @@ bash run_multinpu_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
|
||||||
bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
|
bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To train the model in clusters, command as follows:'''
|
||||||
|
```
|
||||||
|
# deploy wide&deep script in clusters
|
||||||
|
# CLUSTER_CONFIG is a json file, the sample is in script/.
|
||||||
|
# EXECUTE_PATH is the scripts path after the deploy.
|
||||||
|
bash deploy_cluster.sh CLUSTER_CONFIG_PATH EXECUTE_PATH
|
||||||
|
|
||||||
|
# enter EXECUTE_PATH, and execute start_cluster.sh as follows.
|
||||||
|
# MODE: "host_device_mix"
|
||||||
|
bash start_cluster.sh CLUSTER_CONFIG_PATH EPOCH_SIZE VOCAB_SIZE EMB_DIM
|
||||||
|
DATASET ENV_SH RANK_TABLE_FILE MODE
|
||||||
|
```
|
||||||
|
|
||||||
To evaluate the model, command as follows:
|
To evaluate the model, command as follows:
|
||||||
```
|
```
|
||||||
python eval.py
|
python eval.py
|
|
@ -22,7 +22,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
from src.callbacks import LossCallBack, EvalCallBack
|
from src.callbacks import LossCallBack, EvalCallBack
|
||||||
from src.datasets import create_dataset
|
from src.datasets import create_dataset, DataType
|
||||||
from src.metrics import AUCMetric
|
from src.metrics import AUCMetric
|
||||||
from src.config import WideDeepConfig
|
from src.config import WideDeepConfig
|
||||||
|
|
||||||
|
@ -69,8 +69,14 @@ def test_eval(config):
|
||||||
"""
|
"""
|
||||||
data_path = config.data_path
|
data_path = config.data_path
|
||||||
batch_size = config.batch_size
|
batch_size = config.batch_size
|
||||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=2,
|
if config.dataset_type == "tfrecord":
|
||||||
batch_size=batch_size)
|
dataset_type = DataType.TFRECORD
|
||||||
|
elif config.dataset_type == "mindrecord":
|
||||||
|
dataset_type = DataType.MINDRECORD
|
||||||
|
else:
|
||||||
|
dataset_type = DataType.H5
|
||||||
|
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||||
|
batch_size=batch_size, data_type=dataset_type)
|
||||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||||
|
|
||||||
net_builder = ModelBuilder()
|
net_builder = ModelBuilder()
|
|
@ -0,0 +1,21 @@
|
||||||
|
{
|
||||||
|
"rank_size": 32,
|
||||||
|
"cluster": {
|
||||||
|
"xx.xx.xx.xx": {
|
||||||
|
"user": "",
|
||||||
|
"passwd": ""
|
||||||
|
},
|
||||||
|
"xx.xx.xx.xx": {
|
||||||
|
"user": "",
|
||||||
|
"passwd": ""
|
||||||
|
},
|
||||||
|
"xx.xx.xx.xx": {
|
||||||
|
"user": "",
|
||||||
|
"passwd": ""
|
||||||
|
},
|
||||||
|
"xx.xx.xx.xx": {
|
||||||
|
"user": "",
|
||||||
|
"passwd": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,95 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
SSH="ssh -o StrictHostKeyChecking=no"
|
||||||
|
SCP="scp -o StrictHostKeyChecking=no"
|
||||||
|
|
||||||
|
error_msg()
|
||||||
|
{
|
||||||
|
local msg="$*"
|
||||||
|
echo "[ERROR]: $msg" 1>&2
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
ssh_pass()
|
||||||
|
{
|
||||||
|
local node="$1"
|
||||||
|
local user="$2"
|
||||||
|
local passwd="$3"
|
||||||
|
shift 3
|
||||||
|
local cmd="$*"
|
||||||
|
sshpass -p "${passwd}" ${SSH} "${user}"@"${node}" ${cmd}
|
||||||
|
}
|
||||||
|
|
||||||
|
scp_pass()
|
||||||
|
{
|
||||||
|
local node="$1"
|
||||||
|
local user="$2"
|
||||||
|
local passwd="$3"
|
||||||
|
local src="$4"
|
||||||
|
local target="$5"
|
||||||
|
sshpass -p "${passwd}" ${SCP} -r "${src}" "${user}"@"${node}":"${target}"
|
||||||
|
}
|
||||||
|
|
||||||
|
rscp_pass()
|
||||||
|
{
|
||||||
|
local node="$1"
|
||||||
|
local user="$2"
|
||||||
|
local passwd="$3"
|
||||||
|
local src="$4"
|
||||||
|
local target="$5"
|
||||||
|
sshpass -p "${passwd}" ${SCP} -r "${user}"@"${node}":"${src}" "${target}"
|
||||||
|
}
|
||||||
|
|
||||||
|
get_rank_size()
|
||||||
|
{
|
||||||
|
local cluster_config=$1
|
||||||
|
cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["rank_size"])'
|
||||||
|
}
|
||||||
|
|
||||||
|
get_train_dataset()
|
||||||
|
{
|
||||||
|
local cluster_config=$1
|
||||||
|
cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["train_dataset"])'
|
||||||
|
}
|
||||||
|
|
||||||
|
get_cluster_list()
|
||||||
|
{
|
||||||
|
local cluster_config=$1
|
||||||
|
cat ${cluster_config} | python3 -c 'import sys,json;[print(node) for node in json.load(sys.stdin)["cluster"].keys()]' | sort
|
||||||
|
}
|
||||||
|
|
||||||
|
get_node_user()
|
||||||
|
{
|
||||||
|
local cluster_config=$1
|
||||||
|
local node=$2
|
||||||
|
cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["cluster"]['\"${node}\"']["user"])'
|
||||||
|
}
|
||||||
|
|
||||||
|
get_node_passwd()
|
||||||
|
{
|
||||||
|
local cluster_config=$1
|
||||||
|
local node=$2
|
||||||
|
cat ${cluster_config} | python3 -c 'import sys,json;print(json.load(sys.stdin)["cluster"]['\"${node}\"']["passwd"])'
|
||||||
|
}
|
||||||
|
|
||||||
|
rsync_sshpass()
|
||||||
|
{
|
||||||
|
local node=$1
|
||||||
|
local user="$2"
|
||||||
|
local passwd="$3"
|
||||||
|
scp_pass "${node}" "${user}" "${passwd}" /usr/local/bin/sshpass /usr/local/bin/sshpass
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
SCRIPTPATH="$( cd "$(dirname "$0")" || exit ; pwd -P )"
|
||||||
|
# shellcheck source=/dev/null
|
||||||
|
source $SCRIPTPATH/common.sh
|
||||||
|
cluster_config_path=$1
|
||||||
|
execute_path=$2
|
||||||
|
RANK_SIZE=$(get_rank_size ${cluster_config_path})
|
||||||
|
RANK_START=0
|
||||||
|
node_list=$(get_cluster_list ${cluster_config_path})
|
||||||
|
|
||||||
|
for node in ${node_list}
|
||||||
|
do
|
||||||
|
user=$(get_node_user ${cluster_config_path} ${node})
|
||||||
|
passwd=$(get_node_passwd ${cluster_config_path} ${node})
|
||||||
|
echo "------------------${user}@${node}---------------------"
|
||||||
|
ssh_pass ${node} ${user} ${passwd} "rm -rf ${execute_path}"
|
||||||
|
scp_pass ${node} ${user} ${passwd} $SCRIPTPATH/../../wide_and_deep ${execute_path}
|
||||||
|
RANK_START=$[RANK_START+8]
|
||||||
|
if [[ $RANK_START -ge $RANK_SIZE ]]; then
|
||||||
|
break;
|
||||||
|
fi
|
||||||
|
done
|
|
@ -0,0 +1,48 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
execute_path=$(pwd)
|
||||||
|
echo ${execute_path}
|
||||||
|
script_self=$(readlink -f "$0")
|
||||||
|
self_path=$(dirname "${script_self}")
|
||||||
|
echo ${self_path}
|
||||||
|
|
||||||
|
export RANK_SIZE=$1
|
||||||
|
RANK_START=$2
|
||||||
|
EPOCH_SIZE=$3
|
||||||
|
VOCAB_SIZE=$4
|
||||||
|
EMB_DIM=$5
|
||||||
|
DATASET=$6
|
||||||
|
ENV_SH=$7
|
||||||
|
MODE=$8
|
||||||
|
export MINDSPORE_HCCL_CONFIG=$9
|
||||||
|
export RANK_TABLE_FILE=$9
|
||||||
|
DEVICE_START=0
|
||||||
|
# shellcheck source=/dev/null
|
||||||
|
source $ENV_SH
|
||||||
|
for((i=0;i<=7;i++));
|
||||||
|
do
|
||||||
|
export RANK_ID=$[i+RANK_START]
|
||||||
|
export DEVICE_ID=$[i+DEVICE_START]
|
||||||
|
rm -rf ${execute_path}/device_$RANK_ID
|
||||||
|
mkdir ${execute_path}/device_$RANK_ID
|
||||||
|
cd ${execute_path}/device_$RANK_ID || exit
|
||||||
|
if [ $MODE == "host_device_mix" ]; then
|
||||||
|
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 &
|
||||||
|
else
|
||||||
|
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 &
|
||||||
|
fi
|
||||||
|
done
|
|
@ -0,0 +1,51 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
execute_path=$(pwd)
|
||||||
|
echo ${execute_path}
|
||||||
|
script_self=$(readlink -f "$0")
|
||||||
|
SCRIPTPATH=$(dirname "${script_self}")
|
||||||
|
echo ${SCRIPTPATH}
|
||||||
|
# shellcheck source=/dev/null
|
||||||
|
source $SCRIPTPATH/common.sh
|
||||||
|
cluster_config_path=$1
|
||||||
|
RANK_SIZE=$(get_rank_size ${cluster_config_path})
|
||||||
|
RANK_START=0
|
||||||
|
node_list=$(get_cluster_list ${cluster_config_path})
|
||||||
|
EPOCH_SIZE=$2
|
||||||
|
VOCAB_SIZE=$3
|
||||||
|
EMB_DIM=$4
|
||||||
|
DATASET=$5
|
||||||
|
MINDSPORE_HCCL_CONFIG_PATH=$6
|
||||||
|
ENV_SH=$7
|
||||||
|
MODE=$8
|
||||||
|
|
||||||
|
for node in ${node_list}
|
||||||
|
do
|
||||||
|
user=$(get_node_user ${cluster_config_path} ${node})
|
||||||
|
passwd=$(get_node_passwd ${cluster_config_path} ${node})
|
||||||
|
echo "------------------${user}@${node}---------------------"
|
||||||
|
if [ $MODE == "host_device_mix" ]; then
|
||||||
|
ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${MINDSPORE_HCCL_CONFIG_PATH}"
|
||||||
|
else
|
||||||
|
echo "[ERROR] mode is wrong"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
RANK_START=$[RANK_START+8]
|
||||||
|
if [[ $RANK_START -ge $RANK_SIZE ]]; then
|
||||||
|
break;
|
||||||
|
fi
|
||||||
|
done
|
|
@ -51,7 +51,7 @@ class LossCallBack(Callback):
|
||||||
wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
|
wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
|
||||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||||
cur_num = cb_params.cur_step_num
|
cur_num = cb_params.cur_step_num
|
||||||
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)
|
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True)
|
||||||
|
|
||||||
# raise ValueError
|
# raise ValueError
|
||||||
if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
|
if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
|
||||||
|
@ -76,7 +76,7 @@ class EvalCallBack(Callback):
|
||||||
Args:
|
Args:
|
||||||
print_per_step (int): Print loss every times. Default: 1.
|
print_per_step (int): Print loss every times. Default: 1.
|
||||||
"""
|
"""
|
||||||
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1):
|
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1, host_device_mix=False):
|
||||||
super(EvalCallBack, self).__init__()
|
super(EvalCallBack, self).__init__()
|
||||||
if not isinstance(print_per_step, int) or print_per_step < 0:
|
if not isinstance(print_per_step, int) or print_per_step < 0:
|
||||||
raise ValueError("print_per_step must be int and >= 0.")
|
raise ValueError("print_per_step must be int and >= 0.")
|
||||||
|
@ -87,6 +87,7 @@ class EvalCallBack(Callback):
|
||||||
self.aucMetric.clear()
|
self.aucMetric.clear()
|
||||||
self.eval_file_name = config.eval_file_name
|
self.eval_file_name = config.eval_file_name
|
||||||
self.eval_values = []
|
self.eval_values = []
|
||||||
|
self.host_device_mix = host_device_mix
|
||||||
|
|
||||||
def epoch_end(self, run_context):
|
def epoch_end(self, run_context):
|
||||||
"""
|
"""
|
||||||
|
@ -98,7 +99,7 @@ class EvalCallBack(Callback):
|
||||||
context.set_auto_parallel_context(strategy_ckpt_save_file="",
|
context.set_auto_parallel_context(strategy_ckpt_save_file="",
|
||||||
strategy_ckpt_load_file="./strategy_train.ckpt")
|
strategy_ckpt_load_file="./strategy_train.ckpt")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
out = self.model.eval(self.eval_dataset)
|
out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix))
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
eval_time = int(end_time - start_time)
|
eval_time = int(end_time - start_time)
|
||||||
|
|
|
@ -38,6 +38,8 @@ def argparse_init():
|
||||||
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
|
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
|
||||||
parser.add_argument("--eval_file_name", type=str, default="eval.log")
|
parser.add_argument("--eval_file_name", type=str, default="eval.log")
|
||||||
parser.add_argument("--loss_file_name", type=str, default="loss.log")
|
parser.add_argument("--loss_file_name", type=str, default="loss.log")
|
||||||
|
parser.add_argument("--host_device_mix", type=int, default=0)
|
||||||
|
parser.add_argument("--dataset_type", type=str, default="tfrecord")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,6 +70,8 @@ class WideDeepConfig():
|
||||||
self.eval_file_name = "eval.log"
|
self.eval_file_name = "eval.log"
|
||||||
self.loss_file_name = "loss.log"
|
self.loss_file_name = "loss.log"
|
||||||
self.ckpt_path = "./checkpoints/"
|
self.ckpt_path = "./checkpoints/"
|
||||||
|
self.host_device_mix = 0
|
||||||
|
self.dataset_type = "tfrecord"
|
||||||
|
|
||||||
def argparse_init(self):
|
def argparse_init(self):
|
||||||
"""
|
"""
|
||||||
|
@ -97,3 +101,5 @@ class WideDeepConfig():
|
||||||
self.eval_file_name = args.eval_file_name
|
self.eval_file_name = args.eval_file_name
|
||||||
self.loss_file_name = args.loss_file_name
|
self.loss_file_name = args.loss_file_name
|
||||||
self.ckpt_path = args.ckpt_path
|
self.ckpt_path = args.ckpt_path
|
||||||
|
self.host_device_mix = args.host_device_mix
|
||||||
|
self.dataset_type = args.dataset_type
|
|
@ -20,7 +20,7 @@ from mindspore.ops import functional as F
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.nn import Dropout
|
from mindspore.nn import Dropout
|
||||||
from mindspore.nn.optim import Adam, FTRL
|
from mindspore.nn.optim import Adam, FTRL, LazyAdam
|
||||||
# from mindspore.nn.metrics import Metric
|
# from mindspore.nn.metrics import Metric
|
||||||
from mindspore.common.initializer import Uniform, initializer
|
from mindspore.common.initializer import Uniform, initializer
|
||||||
# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||||
|
@ -82,7 +82,7 @@ class DenseLayer(nn.Cell):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
|
def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
|
||||||
keep_prob=0.7, use_activation=True, convert_dtype=True, drop_out=False):
|
keep_prob=0.5, use_activation=True, convert_dtype=True, drop_out=False):
|
||||||
super(DenseLayer, self).__init__()
|
super(DenseLayer, self).__init__()
|
||||||
weight_init, bias_init = weight_bias_init
|
weight_init, bias_init = weight_bias_init
|
||||||
self.weight = init_method(
|
self.weight = init_method(
|
||||||
|
@ -137,8 +137,10 @@ class WideDeepModel(nn.Cell):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(WideDeepModel, self).__init__()
|
super(WideDeepModel, self).__init__()
|
||||||
self.batch_size = config.batch_size
|
self.batch_size = config.batch_size
|
||||||
|
host_device_mix = bool(config.host_device_mix)
|
||||||
parallel_mode = _get_parallel_mode()
|
parallel_mode = _get_parallel_mode()
|
||||||
if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||||
|
if is_auto_parallel:
|
||||||
self.batch_size = self.batch_size * get_group_size()
|
self.batch_size = self.batch_size * get_group_size()
|
||||||
self.field_size = config.field_size
|
self.field_size = config.field_size
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
@ -187,16 +189,29 @@ class WideDeepModel(nn.Cell):
|
||||||
self.weight_bias_init,
|
self.weight_bias_init,
|
||||||
self.deep_layer_act,
|
self.deep_layer_act,
|
||||||
use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
|
use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
|
||||||
|
self.wide_mul = P.Mul()
|
||||||
self.embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
|
self.deep_mul = P.Mul()
|
||||||
self.mul = P.Mul()
|
|
||||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||||
self.reshape = P.Reshape()
|
self.reshape = P.Reshape()
|
||||||
|
self.deep_reshape = P.Reshape()
|
||||||
self.square = P.Square()
|
self.square = P.Square()
|
||||||
self.shape = P.Shape()
|
self.shape = P.Shape()
|
||||||
self.tile = P.Tile()
|
self.tile = P.Tile()
|
||||||
self.concat = P.Concat(axis=1)
|
self.concat = P.Concat(axis=1)
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
|
if is_auto_parallel and host_device_mix:
|
||||||
|
self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
|
||||||
|
self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
|
||||||
|
self.deep_embeddinglookup = nn.EmbeddingLookup()
|
||||||
|
self.deep_embeddinglookup.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
|
||||||
|
self.wide_embeddinglookup = nn.EmbeddingLookup()
|
||||||
|
self.wide_embeddinglookup.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
|
||||||
|
self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1)))
|
||||||
|
self.deep_reshape.add_prim_attr("skip_redistribution", True)
|
||||||
|
self.reduce_sum.add_prim_attr("cross_batch", True)
|
||||||
|
else:
|
||||||
|
self.deep_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
|
||||||
|
self.wide_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
|
||||||
|
|
||||||
def construct(self, id_hldr, wt_hldr):
|
def construct(self, id_hldr, wt_hldr):
|
||||||
"""
|
"""
|
||||||
|
@ -206,13 +221,13 @@ class WideDeepModel(nn.Cell):
|
||||||
"""
|
"""
|
||||||
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
|
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
|
||||||
# Wide layer
|
# Wide layer
|
||||||
wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr)
|
wide_id_weight = self.wide_embeddinglookup(self.wide_w, id_hldr)
|
||||||
wx = self.mul(wide_id_weight, mask)
|
wx = self.wide_mul(wide_id_weight, mask)
|
||||||
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
|
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
|
||||||
# Deep layer
|
# Deep layer
|
||||||
deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr)
|
deep_id_embs = self.deep_embeddinglookup(self.embedding_table, id_hldr)
|
||||||
vx = self.mul(deep_id_embs, mask)
|
vx = self.deep_mul(deep_id_embs, mask)
|
||||||
deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
|
deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
|
||||||
deep_in = self.dense_layer_1(deep_in)
|
deep_in = self.dense_layer_1(deep_in)
|
||||||
deep_in = self.dense_layer_2(deep_in)
|
deep_in = self.dense_layer_2(deep_in)
|
||||||
deep_in = self.dense_layer_3(deep_in)
|
deep_in = self.dense_layer_3(deep_in)
|
||||||
|
@ -233,19 +248,28 @@ class NetWithLossClass(nn.Cell):
|
||||||
|
|
||||||
def __init__(self, network, config):
|
def __init__(self, network, config):
|
||||||
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
||||||
|
host_device_mix = bool(config.host_device_mix)
|
||||||
|
parallel_mode = _get_parallel_mode()
|
||||||
|
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||||
|
self.no_l2loss = host_device_mix and is_auto_parallel
|
||||||
self.network = network
|
self.network = network
|
||||||
self.l2_coef = config.l2_coef
|
self.l2_coef = config.l2_coef
|
||||||
self.loss = P.SigmoidCrossEntropyWithLogits()
|
self.loss = P.SigmoidCrossEntropyWithLogits()
|
||||||
self.square = P.Square()
|
self.square = P.Square()
|
||||||
self.reduceMean_false = P.ReduceMean(keep_dims=False)
|
self.reduceMean_false = P.ReduceMean(keep_dims=False)
|
||||||
|
if is_auto_parallel:
|
||||||
|
self.reduceMean_false.add_prim_attr("cross_batch", True)
|
||||||
self.reduceSum_false = P.ReduceSum(keep_dims=False)
|
self.reduceSum_false = P.ReduceSum(keep_dims=False)
|
||||||
|
|
||||||
def construct(self, batch_ids, batch_wts, label):
|
def construct(self, batch_ids, batch_wts, label):
|
||||||
predict, embedding_table = self.network(batch_ids, batch_wts)
|
predict, embedding_table = self.network(batch_ids, batch_wts)
|
||||||
log_loss = self.loss(predict, label)
|
log_loss = self.loss(predict, label)
|
||||||
wide_loss = self.reduceMean_false(log_loss)
|
wide_loss = self.reduceMean_false(log_loss)
|
||||||
l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
|
if self.no_l2loss:
|
||||||
deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v
|
deep_loss = wide_loss
|
||||||
|
else:
|
||||||
|
l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
|
||||||
|
deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v
|
||||||
|
|
||||||
return wide_loss, deep_loss
|
return wide_loss, deep_loss
|
||||||
|
|
||||||
|
@ -267,12 +291,15 @@ class TrainStepWrap(nn.Cell):
|
||||||
Append Adam and FTRL optimizers to the training network after that construct
|
Append Adam and FTRL optimizers to the training network after that construct
|
||||||
function can be called to create the backward graph.
|
function can be called to create the backward graph.
|
||||||
Args:
|
Args:
|
||||||
network (Cell): the training network. Note that loss function should have been added.
|
network (Cell): The training network. Note that loss function should have been added.
|
||||||
sens (Number): The adjust parameter. Default: 1000.0
|
sens (Number): The adjust parameter. Default: 1024.0
|
||||||
|
host_device_mix (Bool): Whether run in host and device mix mode. Default: False
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, network, sens=1024.0):
|
def __init__(self, network, sens=1024.0, host_device_mix=False):
|
||||||
super(TrainStepWrap, self).__init__()
|
super(TrainStepWrap, self).__init__()
|
||||||
|
parallel_mode = _get_parallel_mode()
|
||||||
|
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||||
self.network = network
|
self.network = network
|
||||||
self.network.set_train()
|
self.network.set_train()
|
||||||
self.trainable_params = network.trainable_params()
|
self.trainable_params = network.trainable_params()
|
||||||
|
@ -285,10 +312,19 @@ class TrainStepWrap(nn.Cell):
|
||||||
weights_d.append(params)
|
weights_d.append(params)
|
||||||
self.weights_w = ParameterTuple(weights_w)
|
self.weights_w = ParameterTuple(weights_w)
|
||||||
self.weights_d = ParameterTuple(weights_d)
|
self.weights_d = ParameterTuple(weights_d)
|
||||||
self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w,
|
|
||||||
l1=1e-8, l2=1e-8, initial_accum=1.0)
|
if host_device_mix and is_auto_parallel:
|
||||||
self.optimizer_d = Adam(
|
self.optimizer_d = LazyAdam(
|
||||||
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
|
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
|
||||||
|
self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
|
||||||
|
l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
|
||||||
|
self.optimizer_w.sparse_opt.add_prim_attr("primitive_target", "CPU")
|
||||||
|
self.optimizer_d.sparse_opt.add_prim_attr("primitive_target", "CPU")
|
||||||
|
else:
|
||||||
|
self.optimizer_d = Adam(
|
||||||
|
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
|
||||||
|
self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
|
||||||
|
l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
|
||||||
self.hyper_map = C.HyperMap()
|
self.hyper_map = C.HyperMap()
|
||||||
self.grad_w = C.GradOperation('grad_w', get_by_list=True,
|
self.grad_w = C.GradOperation('grad_w', get_by_list=True,
|
||||||
sens_param=True)
|
sens_param=True)
|
|
@ -17,7 +17,7 @@ from mindspore import Model, context
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
from src.callbacks import LossCallBack
|
from src.callbacks import LossCallBack
|
||||||
from src.datasets import create_dataset
|
from src.datasets import create_dataset, DataType
|
||||||
from src.config import WideDeepConfig
|
from src.config import WideDeepConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,7 +63,14 @@ def test_train(configure):
|
||||||
data_path = configure.data_path
|
data_path = configure.data_path
|
||||||
batch_size = configure.batch_size
|
batch_size = configure.batch_size
|
||||||
epochs = configure.epochs
|
epochs = configure.epochs
|
||||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
|
if configure.dataset_type == "tfrecord":
|
||||||
|
dataset_type = DataType.TFRECORD
|
||||||
|
elif configure.dataset_type == "mindrecord":
|
||||||
|
dataset_type = DataType.MINDRECORD
|
||||||
|
else:
|
||||||
|
dataset_type = DataType.H5
|
||||||
|
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||||
|
batch_size=batch_size, data_type=dataset_type)
|
||||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||||
|
|
||||||
net_builder = ModelBuilder()
|
net_builder = ModelBuilder()
|
|
@ -19,7 +19,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
|
||||||
|
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
from src.callbacks import LossCallBack, EvalCallBack
|
from src.callbacks import LossCallBack, EvalCallBack
|
||||||
from src.datasets import create_dataset
|
from src.datasets import create_dataset, DataType
|
||||||
from src.metrics import AUCMetric
|
from src.metrics import AUCMetric
|
||||||
from src.config import WideDeepConfig
|
from src.config import WideDeepConfig
|
||||||
|
|
||||||
|
@ -67,8 +67,16 @@ def test_train_eval(config):
|
||||||
data_path = config.data_path
|
data_path = config.data_path
|
||||||
batch_size = config.batch_size
|
batch_size = config.batch_size
|
||||||
epochs = config.epochs
|
epochs = config.epochs
|
||||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
|
if config.dataset_type == "tfrecord":
|
||||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size)
|
dataset_type = DataType.TFRECORD
|
||||||
|
elif config.dataset_type == "mindrecord":
|
||||||
|
dataset_type = DataType.MINDRECORD
|
||||||
|
else:
|
||||||
|
dataset_type = DataType.H5
|
||||||
|
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||||
|
batch_size=batch_size, data_type=dataset_type)
|
||||||
|
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||||
|
batch_size=batch_size, data_type=dataset_type)
|
||||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||||
|
|
|
@ -27,13 +27,14 @@ from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
|
||||||
|
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
from src.callbacks import LossCallBack, EvalCallBack
|
from src.callbacks import LossCallBack, EvalCallBack
|
||||||
from src.datasets import create_dataset
|
from src.datasets import create_dataset, DataType
|
||||||
from src.metrics import AUCMetric
|
from src.metrics import AUCMetric
|
||||||
from src.config import WideDeepConfig
|
from src.config import WideDeepConfig
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
|
context.set_context(variable_memory_max_size="24GB")
|
||||||
|
context.set_context(enable_sparse=True)
|
||||||
cost_model_context.set_cost_model_context(multi_subgraphs=True)
|
cost_model_context.set_cost_model_context(multi_subgraphs=True)
|
||||||
init()
|
init()
|
||||||
|
|
||||||
|
@ -46,7 +47,7 @@ def get_WideDeep_net(config):
|
||||||
WideDeep_net = WideDeepModel(config)
|
WideDeep_net = WideDeepModel(config)
|
||||||
loss_net = NetWithLossClass(WideDeep_net, config)
|
loss_net = NetWithLossClass(WideDeep_net, config)
|
||||||
loss_net = VirtualDatasetCellTriple(loss_net)
|
loss_net = VirtualDatasetCellTriple(loss_net)
|
||||||
train_net = TrainStepWrap(loss_net)
|
train_net = TrainStepWrap(loss_net, host_device_mix=bool(config.host_device_mix))
|
||||||
eval_net = PredictWithSigmoid(WideDeep_net)
|
eval_net = PredictWithSigmoid(WideDeep_net)
|
||||||
eval_net = VirtualDatasetCellTriple(eval_net)
|
eval_net = VirtualDatasetCellTriple(eval_net)
|
||||||
return train_net, eval_net
|
return train_net, eval_net
|
||||||
|
@ -81,19 +82,28 @@ def train_and_eval(config):
|
||||||
data_path = config.data_path
|
data_path = config.data_path
|
||||||
batch_size = config.batch_size
|
batch_size = config.batch_size
|
||||||
epochs = config.epochs
|
epochs = config.epochs
|
||||||
|
if config.dataset_type == "tfrecord":
|
||||||
|
dataset_type = DataType.TFRECORD
|
||||||
|
elif config.dataset_type == "mindrecord":
|
||||||
|
dataset_type = DataType.MINDRECORD
|
||||||
|
else:
|
||||||
|
dataset_type = DataType.H5
|
||||||
|
host_device_mix = bool(config.host_device_mix)
|
||||||
print("epochs is {}".format(epochs))
|
print("epochs is {}".format(epochs))
|
||||||
if config.full_batch:
|
if config.full_batch:
|
||||||
context.set_auto_parallel_context(full_batch=True)
|
context.set_auto_parallel_context(full_batch=True)
|
||||||
de.config.set_seed(1)
|
de.config.set_seed(1)
|
||||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||||
batch_size=batch_size*get_group_size())
|
batch_size=batch_size*get_group_size(), data_type=dataset_type)
|
||||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||||
batch_size=batch_size*get_group_size())
|
batch_size=batch_size*get_group_size(), data_type=dataset_type)
|
||||||
else:
|
else:
|
||||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||||
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
|
batch_size=batch_size, rank_id=get_rank(),
|
||||||
|
rank_size=get_group_size(), data_type=dataset_type)
|
||||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||||
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
|
batch_size=batch_size, rank_id=get_rank(),
|
||||||
|
rank_size=get_group_size(), data_type=dataset_type)
|
||||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||||
|
|
||||||
|
@ -105,18 +115,24 @@ def train_and_eval(config):
|
||||||
|
|
||||||
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
|
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
|
||||||
|
|
||||||
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
|
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
|
||||||
|
|
||||||
callback = LossCallBack(config=config)
|
callback = LossCallBack(config=config)
|
||||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
|
||||||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
||||||
directory=config.ckpt_path, config=ckptconfig)
|
directory=config.ckpt_path, config=ckptconfig)
|
||||||
context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
|
context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
|
||||||
model.train(epochs, ds_train,
|
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
|
||||||
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
|
if not host_device_mix:
|
||||||
|
callback_list.append(ckpoint_cb)
|
||||||
|
model.train(epochs, ds_train, callbacks=callback_list, dataset_sink_mode=(not host_device_mix))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
wide_deep_config = WideDeepConfig()
|
wide_deep_config = WideDeepConfig()
|
||||||
wide_deep_config.argparse_init()
|
wide_deep_config.argparse_init()
|
||||||
|
if wide_deep_config.host_device_mix == 1:
|
||||||
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
|
||||||
|
else:
|
||||||
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
|
||||||
train_and_eval(wide_deep_config)
|
train_and_eval(wide_deep_config)
|
|
@ -25,7 +25,7 @@ from mindspore.communication.management import get_rank, get_group_size, init
|
||||||
|
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
from src.callbacks import LossCallBack, EvalCallBack
|
from src.callbacks import LossCallBack, EvalCallBack
|
||||||
from src.datasets import create_dataset
|
from src.datasets import create_dataset, DataType
|
||||||
from src.metrics import AUCMetric
|
from src.metrics import AUCMetric
|
||||||
from src.config import WideDeepConfig
|
from src.config import WideDeepConfig
|
||||||
|
|
||||||
|
@ -73,11 +73,19 @@ def train_and_eval(config):
|
||||||
data_path = config.data_path
|
data_path = config.data_path
|
||||||
batch_size = config.batch_size
|
batch_size = config.batch_size
|
||||||
epochs = config.epochs
|
epochs = config.epochs
|
||||||
|
if config.dataset_type == "tfrecord":
|
||||||
|
dataset_type = DataType.TFRECORD
|
||||||
|
elif config.dataset_type == "mindrecord":
|
||||||
|
dataset_type = DataType.MINDRECORD
|
||||||
|
else:
|
||||||
|
dataset_type = DataType.H5
|
||||||
print("epochs is {}".format(epochs))
|
print("epochs is {}".format(epochs))
|
||||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||||
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
|
batch_size=batch_size, rank_id=get_rank(),
|
||||||
|
rank_size=get_group_size(), data_type=dataset_type)
|
||||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||||
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
|
batch_size=batch_size, rank_id=get_rank(),
|
||||||
|
rank_size=get_group_size(), data_type=dataset_type)
|
||||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||||
|
|
|
@ -21,10 +21,10 @@ export RANK_SIZE=$DEVICE_NUM
|
||||||
unset SLOG_PRINT_TO_STDOUT
|
unset SLOG_PRINT_TO_STDOUT
|
||||||
export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json
|
export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json
|
||||||
CODE_DIR="./"
|
CODE_DIR="./"
|
||||||
if [ -d ${BASE_PATH}/../../../../model_zoo/wide_and_deep ]; then
|
if [ -d ${BASE_PATH}/../../../../model_zoo/official/recommend/wide_and_deep ]; then
|
||||||
CODE_DIR=${BASE_PATH}/../../../../model_zoo/wide_and_deep
|
CODE_DIR=${BASE_PATH}/../../../../model_zoo/official/recommend/wide_and_deep
|
||||||
elif [ -d ${BASE_PATH}/../../model_zoo/wide_and_deep ]; then
|
elif [ -d ${BASE_PATH}/../../model_zoo/official/recommend/wide_and_deep ]; then
|
||||||
CODE_DIR=${BASE_PATH}/../../model_zoo/wide_and_deep
|
CODE_DIR=${BASE_PATH}/../../model_zoo/official/recommend/wide_and_deep
|
||||||
else
|
else
|
||||||
echo "[ERROR] code dir is not found"
|
echo "[ERROR] code dir is not found"
|
||||||
fi
|
fi
|
||||||
|
|
Loading…
Reference in New Issue