forked from mindspore-Ecosystem/mindspore
modelzoo wide_and_deep_multitable
This commit is contained in:
parent
68ba6532c4
commit
245415f5bd
|
@ -69,8 +69,8 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|||
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
|
||||
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
const size_t thread_num = 8;
|
||||
std::thread threads[8];
|
||||
const size_t thread_num = 16;
|
||||
std::thread threads[16];
|
||||
size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num;
|
||||
size_t i;
|
||||
size_t task_offset = 0;
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_imagenet."""
|
||||
"""train_dataset."""
|
||||
|
||||
|
||||
import os
|
||||
|
|
|
@ -164,9 +164,6 @@ class WideDeepModel(nn.Cell):
|
|||
init_acts = [('Wide_b', [1], self.emb_init)]
|
||||
var_map = init_var_dict(self.init_args, init_acts)
|
||||
self.wide_b = var_map["Wide_b"]
|
||||
if parameter_server:
|
||||
self.wide_w.set_param_ps()
|
||||
self.embedding_table.set_param_ps()
|
||||
self.dense_layer_1 = DenseLayer(self.all_dim_list[0],
|
||||
self.all_dim_list[1],
|
||||
self.weight_bias_init,
|
||||
|
@ -217,6 +214,8 @@ class WideDeepModel(nn.Cell):
|
|||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
|
||||
self.embedding_table = self.deep_embeddinglookup.embedding_table
|
||||
self.wide_w.set_param_ps()
|
||||
self.embedding_table.set_param_ps()
|
||||
else:
|
||||
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
|
||||
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
numpy
|
||||
pandas
|
||||
pickle
|
|
@ -0,0 +1,34 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# bash run_multinpu_train.sh
|
||||
execute_path=$(pwd)
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
export RANK_SIZE=$1
|
||||
export EPOCH_SIZE=$2
|
||||
export DATASET=$3
|
||||
export RANK_TABLE_FILE=$4
|
||||
|
||||
for((i=0;i<$RANK_SIZE;i++));
|
||||
do
|
||||
rm -rf ${execute_path}/device_$i/
|
||||
mkdir ${execute_path}/device_$i/
|
||||
cd ${execute_path}/device_$i/ || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
python -s ${self_path}/../train_and_eval_distribute.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
|
||||
done
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
callbacks
|
||||
"""
|
||||
import time
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
def add_write(file_path, out_str):
|
||||
with open(file_path, 'a+', encoding="utf-8") as file_out:
|
||||
file_out.write(out_str + "\n")
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF terminating training.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0 do not print loss.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, config, per_print_times=1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self.config = config
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""Monitor the loss in training."""
|
||||
cb_params = run_context.original_args()
|
||||
wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), \
|
||||
cb_params.net_outputs[1].asnumpy()
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
cur_num = cb_params.cur_step_num
|
||||
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True)
|
||||
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
|
||||
loss_file = open(self.config.loss_file_name, "a+")
|
||||
loss_file.write(
|
||||
"epoch: %s step: %s, wide_loss is %s, deep_loss is %s" %
|
||||
(cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss,
|
||||
deep_loss))
|
||||
loss_file.write("\n")
|
||||
loss_file.close()
|
||||
print("epoch: %s step: %s, wide_loss is %s, deep_loss is %s" % (
|
||||
cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss,
|
||||
deep_loss))
|
||||
|
||||
class EvalCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
If the loss is NAN or INF terminating training.
|
||||
Note:
|
||||
If per_print_times is 0 do not print loss.
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1):
|
||||
super(EvalCallBack, self).__init__()
|
||||
if not isinstance(print_per_step, int) or print_per_step < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self.print_per_step = print_per_step
|
||||
self.model = model
|
||||
self.eval_dataset = eval_dataset
|
||||
self.aucMetric = auc_metric
|
||||
|
||||
self.aucMetric.clear()
|
||||
self.eval_file_name = config.eval_file_name
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""Monitor the auc in training."""
|
||||
self.aucMetric.clear()
|
||||
start_time = time.time()
|
||||
out = self.model.eval(self.eval_dataset)
|
||||
end_time = time.time()
|
||||
eval_time = int(end_time - start_time)
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
out_str = "{}=====EvalCallBack model.eval(): {} ; eval_time:{}s".format(time_str, out.values(), eval_time)
|
||||
print(out_str)
|
||||
add_write(self.eval_file_name, out_str)
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" config. """
|
||||
import argparse
|
||||
|
||||
|
||||
def argparse_init():
|
||||
"""
|
||||
argparse_init
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='WideDeep')
|
||||
|
||||
parser.add_argument("--data_path", type=str, default="./test_raw_data/") # The location of the input data.
|
||||
parser.add_argument("--epochs", type=int, default=200) # The number of epochs used to train.
|
||||
parser.add_argument("--batch_size", type=int, default=131072) # Batch size for training and evaluation
|
||||
parser.add_argument("--eval_batch_size", type=int, default=131072) # The batch size used for evaluation.
|
||||
parser.add_argument("--deep_layers_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) # The sizes of hidden layers for MLP
|
||||
parser.add_argument("--deep_layers_act", type=str, default='relu') # The act of hidden layers for MLP
|
||||
parser.add_argument("--keep_prob", type=float, default=1.0) # The Embedding size of MF model.
|
||||
parser.add_argument("--adam_lr", type=float, default=0.003) # The Adam lr
|
||||
parser.add_argument("--ftrl_lr", type=float, default=0.1) # The ftrl lr.
|
||||
parser.add_argument("--l2_coef", type=float, default=0.0) # The l2 coefficient.
|
||||
parser.add_argument("--is_tf_dataset", type=bool, default=True) # The l2 coefficient.
|
||||
|
||||
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
|
||||
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") # The location of the checkpoints file.
|
||||
parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file.
|
||||
parser.add_argument("--loss_file_name", type=str, default="loss.log") # Loss output file.
|
||||
return parser
|
||||
|
||||
|
||||
class WideDeepConfig():
|
||||
"""
|
||||
WideDeepConfig
|
||||
"""
|
||||
def __init__(self):
|
||||
self.data_path = ''
|
||||
self.epochs = 200
|
||||
self.batch_size = 131072
|
||||
self.eval_batch_size = 131072
|
||||
self.deep_layers_act = 'relu'
|
||||
self.weight_bias_init = ['normal', 'normal']
|
||||
self.emb_init = 'normal'
|
||||
self.init_args = [-0.01, 0.01]
|
||||
self.dropout_flag = False
|
||||
self.keep_prob = 1.0
|
||||
self.l2_coef = 0.0
|
||||
|
||||
self.adam_lr = 0.003
|
||||
|
||||
self.ftrl_lr = 0.1
|
||||
|
||||
self.is_tf_dataset = True
|
||||
self.input_emb_dim = 0
|
||||
self.output_path = "./output/"
|
||||
self.eval_file_name = "eval.log"
|
||||
self.loss_file_name = "loss.log"
|
||||
self.ckpt_path = "./checkpoints/"
|
||||
|
||||
def argparse_init(self):
|
||||
"""
|
||||
argparse_init
|
||||
"""
|
||||
parser = argparse_init()
|
||||
args, _ = parser.parse_known_args()
|
||||
self.data_path = args.data_path
|
||||
self.epochs = args.epochs
|
||||
self.batch_size = args.batch_size
|
||||
self.eval_batch_size = args.eval_batch_size
|
||||
self.deep_layers_act = args.deep_layers_act
|
||||
self.keep_prob = args.keep_prob
|
||||
self.weight_bias_init = ['normal', 'normal']
|
||||
self.emb_init = 'normal'
|
||||
self.init_args = [-0.01, 0.01]
|
||||
self.dropout_flag = False
|
||||
self.l2_coef = args.l2_coef
|
||||
self.ftrl_lr = args.ftrl_lr
|
||||
self.adam_lr = args.adam_lr
|
||||
self.is_tf_dataset = args.is_tf_dataset
|
||||
|
||||
self.output_path = args.output_path
|
||||
self.eval_file_name = args.eval_file_name
|
||||
self.loss_file_name = args.loss_file_name
|
||||
self.ckpt_path = args.ckpt_path
|
|
@ -0,0 +1,341 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_dataset."""
|
||||
import os
|
||||
import math
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
class H5Dataset():
|
||||
"""
|
||||
H5Dataset
|
||||
"""
|
||||
input_length = 39
|
||||
|
||||
def __init__(self,
|
||||
data_path,
|
||||
train_mode=True,
|
||||
train_num_of_parts=21,
|
||||
test_num_of_parts=3):
|
||||
self._hdf_data_dir = data_path
|
||||
self._is_training = train_mode
|
||||
|
||||
if self._is_training:
|
||||
self._file_prefix = 'train'
|
||||
self._num_of_parts = train_num_of_parts
|
||||
else:
|
||||
self._file_prefix = 'test'
|
||||
self._num_of_parts = test_num_of_parts
|
||||
|
||||
self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix,
|
||||
self._num_of_parts)
|
||||
print("data_size: {}".format(self.data_size))
|
||||
|
||||
def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts):
|
||||
size = 0
|
||||
for part in range(num_of_parts):
|
||||
_y = pd.read_hdf(
|
||||
os.path.join(hdf_data_dir, file_prefix + '_output_part_' +
|
||||
str(part) + '.h5'))
|
||||
size += _y.shape[0]
|
||||
return size
|
||||
|
||||
def _iterate_hdf_files_(self, num_of_parts=None, shuffle_block=False):
|
||||
"""
|
||||
iterate among hdf files(blocks). when the whole data set is finished, the iterator restarts
|
||||
from the beginning, thus the data stream will never stop
|
||||
:param train_mode: True or false,false is eval_mode,
|
||||
this file iterator will go through the train set
|
||||
:param num_of_parts: number of files
|
||||
:param shuffle_block: shuffle block files at every round
|
||||
:return: input_hdf_file_name, output_hdf_file_name, finish_flag
|
||||
"""
|
||||
parts = np.arange(num_of_parts)
|
||||
while True:
|
||||
if shuffle_block:
|
||||
for _ in range(int(shuffle_block)):
|
||||
np.random.shuffle(parts)
|
||||
for i, p in enumerate(parts):
|
||||
yield os.path.join(self._hdf_data_dir,
|
||||
self._file_prefix + '_input_part_' + str(
|
||||
p) + '.h5'), \
|
||||
os.path.join(self._hdf_data_dir,
|
||||
self._file_prefix + '_output_part_' + str(
|
||||
p) + '.h5'), i + 1 == len(parts)
|
||||
|
||||
def _generator(self, X, y, batch_size, shuffle=True):
|
||||
"""
|
||||
should be accessed only in private
|
||||
:param X:
|
||||
:param y:
|
||||
:param batch_size:
|
||||
:param shuffle:
|
||||
:return:
|
||||
"""
|
||||
number_of_batches = np.ceil(1. * X.shape[0] / batch_size)
|
||||
counter = 0
|
||||
finished = False
|
||||
sample_index = np.arange(X.shape[0])
|
||||
if shuffle:
|
||||
for _ in range(int(shuffle)):
|
||||
np.random.shuffle(sample_index)
|
||||
assert X.shape[0] > 0
|
||||
while True:
|
||||
batch_index = sample_index[batch_size * counter:batch_size *
|
||||
(counter + 1)]
|
||||
X_batch = X[batch_index]
|
||||
y_batch = y[batch_index]
|
||||
counter += 1
|
||||
yield X_batch, y_batch, finished
|
||||
if counter == number_of_batches:
|
||||
counter = 0
|
||||
finished = True
|
||||
|
||||
def batch_generator(self,
|
||||
batch_size=1000,
|
||||
random_sample=False,
|
||||
shuffle_block=False):
|
||||
"""
|
||||
:param train_mode: True or false,false is eval_mode,
|
||||
:param batch_size
|
||||
:param num_of_parts: number of files
|
||||
:param random_sample: if True, will shuffle
|
||||
:param shuffle_block: shuffle file blocks at every round
|
||||
:return:
|
||||
"""
|
||||
|
||||
for hdf_in, hdf_out, _ in self._iterate_hdf_files_(
|
||||
self._num_of_parts, shuffle_block):
|
||||
start = stop = None
|
||||
X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values
|
||||
y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values
|
||||
data_gen = self._generator(X_all,
|
||||
y_all,
|
||||
batch_size,
|
||||
shuffle=random_sample)
|
||||
finished = False
|
||||
|
||||
while not finished:
|
||||
X, y, finished = data_gen.__next__()
|
||||
X_id = X[:, 0:self.input_length]
|
||||
X_va = X[:, self.input_length:]
|
||||
yield np.array(X_id.astype(dtype=np.int32)), np.array(
|
||||
X_va.astype(dtype=np.float32)), np.array(
|
||||
y.astype(dtype=np.float32))
|
||||
|
||||
|
||||
def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
|
||||
"""
|
||||
_get_h5_dataset
|
||||
"""
|
||||
data_para = {
|
||||
'batch_size': batch_size,
|
||||
}
|
||||
if train_mode:
|
||||
data_para['random_sample'] = True
|
||||
data_para['shuffle_block'] = True
|
||||
|
||||
h5_dataset = H5Dataset(data_path=data_dir, train_mode=train_mode)
|
||||
numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size)
|
||||
|
||||
def _iter_h5_data():
|
||||
train_eval_gen = h5_dataset.batch_generator(**data_para)
|
||||
for _ in range(0, numbers_of_batch, 1):
|
||||
yield train_eval_gen.__next__()
|
||||
|
||||
ds = de.GeneratorDataset(_iter_h5_data(),
|
||||
["ids", "weights", "labels"])
|
||||
ds.set_dataset_size(numbers_of_batch)
|
||||
ds = ds.repeat(epochs)
|
||||
return ds
|
||||
|
||||
|
||||
def _get_tf_dataset(data_dir,
|
||||
schema_dict,
|
||||
input_shape_dict,
|
||||
train_mode=True,
|
||||
epochs=1,
|
||||
batch_size=4096,
|
||||
line_per_sample=4096,
|
||||
rank_size=None,
|
||||
rank_id=None):
|
||||
"""
|
||||
_get_tf_dataset
|
||||
"""
|
||||
dataset_files = []
|
||||
file_prefix_name = 'train' if train_mode else 'eval'
|
||||
shuffle = bool(train_mode)
|
||||
for (dirpath, _, filenames) in os.walk(data_dir):
|
||||
for filename in filenames:
|
||||
if file_prefix_name in filename and "tfrecord" in filename:
|
||||
dataset_files.append(os.path.join(dirpath, filename))
|
||||
schema = de.Schema()
|
||||
|
||||
float_key_list = ["label", "continue_val"]
|
||||
|
||||
columns_list = []
|
||||
for key, attr_dict in schema_dict.items():
|
||||
print("key: {}; shape: {}".format(key, attr_dict["tf_shape"]))
|
||||
columns_list.append(key)
|
||||
if key in set(float_key_list):
|
||||
ms_dtype = mstype.float32
|
||||
else:
|
||||
ms_dtype = mstype.int32
|
||||
schema.add_column(key, de_type=ms_dtype)
|
||||
|
||||
if rank_size is not None and rank_id is not None:
|
||||
ds = de.TFRecordDataset(dataset_files=dataset_files,
|
||||
shuffle=shuffle,
|
||||
schema=schema,
|
||||
num_parallel_workers=8,
|
||||
num_shards=rank_size,
|
||||
shard_id=rank_id,
|
||||
shard_equal_rows=True)
|
||||
else:
|
||||
ds = de.TFRecordDataset(dataset_files=dataset_files,
|
||||
shuffle=shuffle,
|
||||
schema=schema,
|
||||
num_parallel_workers=8)
|
||||
ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
|
||||
|
||||
operations_list = []
|
||||
for key in columns_list:
|
||||
operations_list.append(lambda x: np.array(x).flatten().reshape(input_shape_dict[key]))
|
||||
print("ssssssssssssssssssssss---------------------" * 10)
|
||||
print(input_shape_dict)
|
||||
print("---------------------" * 10)
|
||||
print(schema_dict)
|
||||
|
||||
def mixup(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u):
|
||||
a = np.asarray(a.reshape(batch_size,))
|
||||
b = np.array(b).flatten().reshape(batch_size, -1)
|
||||
c = np.array(c).flatten().reshape(batch_size, -1)
|
||||
d = np.array(d).flatten().reshape(batch_size, -1)
|
||||
e = np.array(e).flatten().reshape(batch_size, -1)
|
||||
|
||||
f = np.array(f).flatten().reshape(batch_size, -1)
|
||||
g = np.array(g).flatten().reshape(batch_size, -1)
|
||||
h = np.array(h).flatten().reshape(batch_size, -1)
|
||||
i = np.array(i).flatten().reshape(batch_size, -1)
|
||||
j = np.array(j).flatten().reshape(batch_size, -1)
|
||||
|
||||
k = np.array(k).flatten().reshape(batch_size, -1)
|
||||
l = np.array(l).flatten().reshape(batch_size, -1)
|
||||
m = np.array(m).flatten().reshape(batch_size, -1)
|
||||
n = np.array(n).flatten().reshape(batch_size, -1)
|
||||
o = np.array(o).flatten().reshape(batch_size, -1)
|
||||
|
||||
p = np.array(p).flatten().reshape(batch_size, -1)
|
||||
q = np.array(q).flatten().reshape(batch_size, -1)
|
||||
r = np.array(r).flatten().reshape(batch_size, -1)
|
||||
s = np.array(s).flatten().reshape(batch_size, -1)
|
||||
t = np.array(t).flatten().reshape(batch_size, -1)
|
||||
|
||||
u = np.array(u).flatten().reshape(batch_size, -1)
|
||||
return a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u
|
||||
|
||||
ds = ds.map(
|
||||
operations=mixup,
|
||||
input_columns=[
|
||||
'label', 'continue_val', 'indicator_id', 'emb_128_id',
|
||||
'emb_64_single_id', 'multi_doc_ad_category_id',
|
||||
'multi_doc_ad_category_id_mask', 'multi_doc_event_entity_id',
|
||||
'multi_doc_event_entity_id_mask', 'multi_doc_ad_entity_id',
|
||||
'multi_doc_ad_entity_id_mask', 'multi_doc_event_topic_id',
|
||||
'multi_doc_event_topic_id_mask', 'multi_doc_event_category_id',
|
||||
'multi_doc_event_category_id_mask', 'multi_doc_ad_topic_id',
|
||||
'multi_doc_ad_topic_id_mask', 'ad_id', 'display_ad_and_is_leak',
|
||||
'display_id', 'is_leak'
|
||||
],
|
||||
columns_order=[
|
||||
'label', 'continue_val', 'indicator_id', 'emb_128_id',
|
||||
'emb_64_single_id', 'multi_doc_ad_category_id',
|
||||
'multi_doc_ad_category_id_mask', 'multi_doc_event_entity_id',
|
||||
'multi_doc_event_entity_id_mask', 'multi_doc_ad_entity_id',
|
||||
'multi_doc_ad_entity_id_mask', 'multi_doc_event_topic_id',
|
||||
'multi_doc_event_topic_id_mask', 'multi_doc_event_category_id',
|
||||
'multi_doc_event_category_id_mask', 'multi_doc_ad_topic_id',
|
||||
'multi_doc_ad_topic_id_mask', 'display_id', 'ad_id',
|
||||
'display_ad_and_is_leak', 'is_leak'
|
||||
],
|
||||
num_parallel_workers=8)
|
||||
|
||||
ds = ds.repeat(epochs)
|
||||
return ds
|
||||
|
||||
|
||||
def compute_emb_dim(config):
|
||||
"""
|
||||
compute_emb_dim
|
||||
"""
|
||||
with open(
|
||||
os.path.join(config.data_path + 'dataformat/',
|
||||
"input_shape_dict.pkl"), "rb") as file_in:
|
||||
input_shape_dict = pickle.load(file_in)
|
||||
input_field_size = {}
|
||||
for key, shape in input_shape_dict.items():
|
||||
if len(shape) < 2:
|
||||
input_field_size[key] = 1
|
||||
else:
|
||||
input_field_size[key] = shape[1]
|
||||
multi_key_list = [
|
||||
"multi_doc_event_topic_id", "multi_doc_event_entity_id",
|
||||
"multi_doc_ad_category_id", "multi_doc_event_category_id",
|
||||
"multi_doc_ad_entity_id", "multi_doc_ad_topic_id"
|
||||
]
|
||||
|
||||
config.input_emb_dim = input_field_size["continue_val"] + \
|
||||
input_field_size["indicator_id"] * 64 + \
|
||||
input_field_size["emb_128_id"] * 128 + \
|
||||
input_field_size["emb_64_single_id"] * 64 + \
|
||||
len(multi_key_list) * 64
|
||||
|
||||
|
||||
def create_dataset(data_dir,
|
||||
train_mode=True,
|
||||
epochs=1,
|
||||
batch_size=4096,
|
||||
is_tf_dataset=True,
|
||||
line_per_sample=4096,
|
||||
rank_size=None,
|
||||
rank_id=None):
|
||||
"""
|
||||
create_dataset
|
||||
"""
|
||||
if is_tf_dataset:
|
||||
with open(os.path.join(data_dir + 'dataformat/', "schema_dict.pkl"),
|
||||
"rb") as file_in:
|
||||
print(os.path.join(data_dir + 'dataformat/', "schema_dict.pkl"))
|
||||
schema_dict = pickle.load(file_in)
|
||||
with open(
|
||||
os.path.join(data_dir + 'dataformat/', "input_shape_dict.pkl"),
|
||||
"rb") as file_in:
|
||||
input_shape_dict = pickle.load(file_in)
|
||||
return _get_tf_dataset(data_dir,
|
||||
schema_dict,
|
||||
input_shape_dict,
|
||||
train_mode,
|
||||
epochs,
|
||||
batch_size,
|
||||
line_per_sample,
|
||||
rank_size=rank_size,
|
||||
rank_id=rank_id)
|
||||
if rank_size is not None and rank_size > 1:
|
||||
raise RuntimeError("please use tfrecord dataset.")
|
||||
return _get_h5_dataset(data_dir, train_mode, epochs, batch_size)
|
|
@ -0,0 +1,153 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Area under cure metric
|
||||
"""
|
||||
import time
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import roc_auc_score, average_precision_score
|
||||
from mindspore.nn.metrics import Metric
|
||||
|
||||
|
||||
def groupby_df_v1(test_df, gb_key):
|
||||
"""
|
||||
groupby_df_v1
|
||||
"""
|
||||
data_groups = test_df.groupby(gb_key)
|
||||
return data_groups
|
||||
|
||||
|
||||
def _compute_metric_v1(batch_groups, topk):
|
||||
"""
|
||||
_compute_metric_v1
|
||||
"""
|
||||
results = []
|
||||
for df in batch_groups:
|
||||
df = df.sort_values(by="preds", ascending=False)
|
||||
if df.shape[0] > topk:
|
||||
df = df.head(topk)
|
||||
preds = df["preds"].values
|
||||
labels = df["labels"].values
|
||||
if np.sum(labels) > 0:
|
||||
results.append(average_precision_score(labels, preds))
|
||||
else:
|
||||
results.append(0.0)
|
||||
return results
|
||||
|
||||
|
||||
def mean_AP_topk(batch_labels, batch_preds, topk=12):
|
||||
"""
|
||||
mean_AP_topk
|
||||
"""
|
||||
def ap_score(label, y_preds, topk):
|
||||
ind_list = np.argsort(y_preds)[::-1]
|
||||
ind_list = ind_list[:topk]
|
||||
if label not in set(ind_list):
|
||||
return 0.0
|
||||
rank = list(ind_list).index(label)
|
||||
return 1.0 / (rank + 1)
|
||||
|
||||
mAP_list = []
|
||||
for label, preds in zip(batch_labels, batch_preds):
|
||||
mAP = ap_score(label, preds, topk)
|
||||
mAP_list.append(mAP)
|
||||
return mAP_list
|
||||
|
||||
|
||||
def new_compute_mAP(test_df, gb_key="display_ids", top_k=12):
|
||||
"""
|
||||
new_compute_mAP
|
||||
"""
|
||||
total_start = time.time()
|
||||
display_ids = test_df["display_ids"]
|
||||
labels = test_df["labels"]
|
||||
predictions = test_df["preds"]
|
||||
|
||||
test_df.sort_values(by=[gb_key], inplace=True, ascending=True)
|
||||
display_ids = test_df["display_ids"]
|
||||
labels = test_df["labels"]
|
||||
predictions = test_df["preds"]
|
||||
|
||||
_, display_ids_idx = np.unique(display_ids, return_index=True)
|
||||
|
||||
preds = np.split(predictions.tolist(), display_ids_idx.tolist()[1:])
|
||||
labels = np.split(labels.tolist(), display_ids_idx.tolist()[1:])
|
||||
|
||||
def pad_fn(ele_l):
|
||||
res_list = ele_l + [0.0 for i in range(30 - len(ele_l))]
|
||||
return res_list
|
||||
|
||||
preds = list(map(lambda x: pad_fn(x.tolist()), preds))
|
||||
labels = [np.argmax(l) for l in labels]
|
||||
|
||||
result_list = []
|
||||
|
||||
batch_size = 100000
|
||||
for idx in range(0, len(labels), batch_size):
|
||||
batch_labels = labels[idx:idx + batch_size]
|
||||
batch_preds = preds[idx:idx + batch_size]
|
||||
meanAP = mean_AP_topk(batch_labels, batch_preds, topk=top_k)
|
||||
result_list.extend(meanAP)
|
||||
mean_AP = np.mean(result_list)
|
||||
print("compute time: {}".format(time.time() - total_start))
|
||||
print("mean_AP: {}".format(mean_AP))
|
||||
return mean_AP
|
||||
|
||||
|
||||
class AUCMetric(Metric):
|
||||
"""
|
||||
AUCMetric
|
||||
"""
|
||||
def __init__(self):
|
||||
super(AUCMetric, self).__init__()
|
||||
self.index = 1
|
||||
|
||||
def clear(self):
|
||||
"""Clear the internal evaluation result."""
|
||||
self.true_labels = []
|
||||
self.pred_probs = []
|
||||
self.display_id = []
|
||||
|
||||
def update(self, *inputs):
|
||||
"""
|
||||
update
|
||||
"""
|
||||
all_predict = inputs[1].asnumpy() # predict
|
||||
all_label = inputs[2].asnumpy() # label
|
||||
all_display_id = inputs[3].asnumpy() # label
|
||||
self.true_labels.extend(all_label.flatten().tolist())
|
||||
self.pred_probs.extend(all_predict.flatten().tolist())
|
||||
self.display_id.extend(all_display_id.flatten().tolist())
|
||||
|
||||
def eval(self):
|
||||
"""
|
||||
eval
|
||||
"""
|
||||
if len(self.true_labels) != len(self.pred_probs):
|
||||
raise RuntimeError(
|
||||
'true_labels.size() is not equal to pred_probs.size()')
|
||||
|
||||
result_df = pd.DataFrame({
|
||||
"display_ids": self.display_id,
|
||||
"preds": self.pred_probs,
|
||||
"labels": self.true_labels,
|
||||
})
|
||||
auc = roc_auc_score(self.true_labels, self.pred_probs)
|
||||
|
||||
MAP = new_compute_mAP(result_df, gb_key="display_ids", top_k=12)
|
||||
print("=====" * 20 + " auc_metric end ")
|
||||
print("=====" * 20 + " auc: {}, map: {}".format(auc, MAP))
|
||||
return auc
|
|
@ -0,0 +1,638 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""wide and deep model"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore import Tensor, Parameter, ParameterTuple
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Dropout, Flatten
|
||||
from mindspore.nn.optim import Adam, FTRL
|
||||
from mindspore.common.initializer import Uniform, initializer
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
|
||||
|
||||
np_type = np.float32
|
||||
ms_type = mstype.float32
|
||||
|
||||
|
||||
def init_method(method, shape, name, max_val=1.0):
|
||||
"""
|
||||
Init method
|
||||
"""
|
||||
if method in ['uniform']:
|
||||
params = Parameter(initializer(Uniform(max_val), shape, ms_type),
|
||||
name=name)
|
||||
elif method == "one":
|
||||
params = Parameter(initializer("ones", shape, ms_type), name=name)
|
||||
elif method == 'zero':
|
||||
params = Parameter(initializer("zeros", shape, ms_type), name=name)
|
||||
elif method == "normal":
|
||||
params = Parameter(Tensor(
|
||||
np.random.normal(loc=0.0, scale=0.01,
|
||||
size=shape).astype(dtype=np_type)),
|
||||
name=name)
|
||||
return params
|
||||
|
||||
|
||||
def init_var_dict(init_args, in_vars):
|
||||
"""
|
||||
Init parameters by dict
|
||||
"""
|
||||
var_map = {}
|
||||
_, _max_val = init_args
|
||||
for _, iterm in enumerate(in_vars):
|
||||
key, shape, method = iterm
|
||||
if key not in var_map.keys():
|
||||
if method in ['random', 'uniform']:
|
||||
var_map[key] = Parameter(initializer(Uniform(_max_val), shape,
|
||||
ms_type),
|
||||
name=key)
|
||||
elif method == "one":
|
||||
var_map[key] = Parameter(initializer("ones", shape, ms_type),
|
||||
name=key)
|
||||
elif method == "zero":
|
||||
var_map[key] = Parameter(initializer("zeros", shape, ms_type),
|
||||
name=key)
|
||||
elif method == 'normal':
|
||||
var_map[key] = Parameter(Tensor(
|
||||
np.random.normal(loc=0.0, scale=0.01,
|
||||
size=shape).astype(dtype=np_type)),
|
||||
name=key)
|
||||
return var_map
|
||||
|
||||
|
||||
class DenseLayer(nn.Cell):
|
||||
"""
|
||||
Dense Layer for Deep Layer of WideDeep Model;
|
||||
Containing: activation, matmul, bias_add;
|
||||
"""
|
||||
def __init__(self,
|
||||
input_dim,
|
||||
output_dim,
|
||||
weight_bias_init,
|
||||
act_str,
|
||||
keep_prob=0.7,
|
||||
scale_coef=1.0,
|
||||
convert_dtype=True):
|
||||
super(DenseLayer, self).__init__()
|
||||
weight_init, bias_init = weight_bias_init
|
||||
self.weight = init_method(weight_init, [input_dim, output_dim],
|
||||
name="weight")
|
||||
self.bias = init_method(bias_init, [output_dim], name="bias")
|
||||
self.act_func = self._init_activation(act_str)
|
||||
self.matmul = P.MatMul(transpose_b=False)
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.cast = P.Cast()
|
||||
self.dropout = Dropout(keep_prob=0.8)
|
||||
self.mul = P.Mul()
|
||||
self.realDiv = P.RealDiv()
|
||||
self.scale_coef = scale_coef
|
||||
self.convert_dtype = convert_dtype
|
||||
|
||||
def _init_activation(self, act_str):
|
||||
act_str = act_str.lower()
|
||||
if act_str == "relu":
|
||||
act_func = P.ReLU()
|
||||
elif act_str == "sigmoid":
|
||||
act_func = P.Sigmoid()
|
||||
elif act_str == "tanh":
|
||||
act_func = P.Tanh()
|
||||
return act_func
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
DenseLayer construct
|
||||
"""
|
||||
x = self.act_func(x)
|
||||
if self.training:
|
||||
x = self.dropout(x)
|
||||
x = self.mul(x, self.scale_coef)
|
||||
if self.convert_dtype:
|
||||
x = self.cast(x, mstype.float16)
|
||||
weight = self.cast(self.weight, mstype.float16)
|
||||
wx = self.matmul(x, weight)
|
||||
wx = self.cast(wx, mstype.float32)
|
||||
else:
|
||||
wx = self.matmul(x, self.weight)
|
||||
wx = self.realDiv(wx, self.scale_coef)
|
||||
output = self.bias_add(wx, self.bias)
|
||||
return output
|
||||
|
||||
|
||||
class WideDeepModel(nn.Cell):
|
||||
"""
|
||||
From paper: " Wide & Deep Learning for Recommender Systems"
|
||||
Args:
|
||||
config (Class): The default config of Wide&Deep
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(WideDeepModel, self).__init__()
|
||||
emb_128_size = 650000
|
||||
emb64_single_size = 17300
|
||||
emb64_multi_size = 20900
|
||||
indicator_size = 16
|
||||
deep_dim_list = [1024, 1024, 1024, 1024, 1024]
|
||||
# deep_dropout=0.0
|
||||
wide_reg_coef = [0.0, 0.0]
|
||||
deep_reg_coef = [0.0, 0.0]
|
||||
wide_lr = 0.2
|
||||
deep_lr = 1.0
|
||||
|
||||
self.input_emb_dim = config.input_emb_dim
|
||||
self.batch_size = config.batch_size
|
||||
self.deep_layer_act = config.deep_layers_act
|
||||
self.init_args = config.init_args
|
||||
self.weight_init, self.bias_init = config.weight_bias_init
|
||||
self.weight_bias_init = config.weight_bias_init
|
||||
self.emb_init = config.emb_init
|
||||
|
||||
self.keep_prob = config.keep_prob
|
||||
self.layer_dims = deep_dim_list + [1]
|
||||
self.all_dim_list = [self.input_emb_dim] + self.layer_dims
|
||||
|
||||
self.continue_field_size = 32
|
||||
self.emb_128_size = emb_128_size
|
||||
self.emb64_single_size = emb64_single_size
|
||||
self.emb64_multi_size = emb64_multi_size
|
||||
self.indicator_size = indicator_size
|
||||
|
||||
self.wide_l1_coef, self.wide_l2_coef = wide_reg_coef
|
||||
self.deep_l1_coef, self.deep_l2_coef = deep_reg_coef
|
||||
self.wide_lr = wide_lr
|
||||
self.deep_lr = deep_lr
|
||||
|
||||
init_acts_embedding_metrix = [
|
||||
('emb128_embedding', [self.emb_128_size, 128], self.emb_init),
|
||||
('emb64_single', [self.emb64_single_size, 64], self.emb_init),
|
||||
('emb64_multi', [self.emb64_multi_size, 64], self.emb_init),
|
||||
('emb64_indicator', [self.indicator_size, 64], self.emb_init)
|
||||
]
|
||||
var_map = init_var_dict(self.init_args, init_acts_embedding_metrix)
|
||||
self.emb128_embedding = var_map["emb128_embedding"]
|
||||
self.emb64_single = var_map["emb64_single"]
|
||||
self.emb64_multi = var_map["emb64_multi"]
|
||||
self.emb64_indicator = var_map["emb64_indicator"]
|
||||
|
||||
init_acts_wide_weight = [
|
||||
('wide_continue_w', [self.continue_field_size], self.emb_init),
|
||||
('wide_emb128_w', [self.emb_128_size], self.emb_init),
|
||||
('wide_emb64_single_w', [self.emb64_single_size], self.emb_init),
|
||||
('wide_emb64_multi_w', [self.emb64_multi_size], self.emb_init),
|
||||
('wide_indicator_w', [self.indicator_size], self.emb_init),
|
||||
('wide_bias', [1], self.emb_init)
|
||||
]
|
||||
var_map = init_var_dict(self.init_args, init_acts_wide_weight)
|
||||
self.wide_continue_w = var_map["wide_continue_w"]
|
||||
self.wide_emb128_w = var_map["wide_emb128_w"]
|
||||
self.wide_emb64_single_w = var_map["wide_emb64_single_w"]
|
||||
self.wide_emb64_multi_w = var_map["wide_emb64_multi_w"]
|
||||
self.wide_indicator_w = var_map["wide_indicator_w"]
|
||||
self.wide_bias = var_map["wide_bias"]
|
||||
|
||||
self.dense_layer_1 = DenseLayer(self.all_dim_list[0],
|
||||
self.all_dim_list[1],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
convert_dtype=True)
|
||||
self.dense_layer_2 = DenseLayer(self.all_dim_list[1],
|
||||
self.all_dim_list[2],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
convert_dtype=True)
|
||||
self.dense_layer_3 = DenseLayer(self.all_dim_list[2],
|
||||
self.all_dim_list[3],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
convert_dtype=True)
|
||||
self.dense_layer_4 = DenseLayer(self.all_dim_list[3],
|
||||
self.all_dim_list[4],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
convert_dtype=True)
|
||||
self.dense_layer_5 = DenseLayer(self.all_dim_list[4],
|
||||
self.all_dim_list[5],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
convert_dtype=True)
|
||||
|
||||
self.deep_predict = DenseLayer(self.all_dim_list[5],
|
||||
self.all_dim_list[6],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
convert_dtype=True)
|
||||
|
||||
self.gather_v2 = P.GatherV2()
|
||||
self.mul = P.Mul()
|
||||
self.reduce_sum_false = P.ReduceSum(keep_dims=False)
|
||||
self.reduce_sum_true = P.ReduceSum(keep_dims=True)
|
||||
self.reshape = P.Reshape()
|
||||
self.square = P.Square()
|
||||
self.shape = P.Shape()
|
||||
self.tile = P.Tile()
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.cast = P.Cast()
|
||||
self.reduceMean_false = P.ReduceMean(keep_dims=False)
|
||||
self.Concat = P.Concat(axis=1)
|
||||
self.BiasAdd = P.BiasAdd()
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.flatten = Flatten()
|
||||
|
||||
def construct(self, continue_val, indicator_id, emb_128_id,
|
||||
emb_64_single_id, multi_doc_ad_category_id,
|
||||
multi_doc_ad_category_id_mask, multi_doc_event_entity_id,
|
||||
multi_doc_event_entity_id_mask, multi_doc_ad_entity_id,
|
||||
multi_doc_ad_entity_id_mask, multi_doc_event_topic_id,
|
||||
multi_doc_event_topic_id_mask, multi_doc_event_category_id,
|
||||
multi_doc_event_category_id_mask, multi_doc_ad_topic_id,
|
||||
multi_doc_ad_topic_id_mask, display_id, ad_id,
|
||||
display_ad_and_is_leak, is_leak):
|
||||
"""
|
||||
Args:
|
||||
id_hldr: batch ids;
|
||||
wt_hldr: batch weights;
|
||||
"""
|
||||
|
||||
val_hldr = continue_val
|
||||
ind_hldr = indicator_id
|
||||
emb128_id_hldr = emb_128_id
|
||||
emb64_single_hldr = emb_64_single_id
|
||||
|
||||
ind_emb = self.gather_v2(self.emb64_indicator, ind_hldr, 0)
|
||||
ind_emb = self.flatten(ind_emb)
|
||||
|
||||
emb128_id_emb = self.gather_v2(self.emb128_embedding, emb128_id_hldr,
|
||||
0)
|
||||
emb128_id_emb = self.flatten(emb128_id_emb)
|
||||
|
||||
emb64_sgl_emb = self.gather_v2(self.emb64_single, emb64_single_hldr, 0)
|
||||
emb64_sgl_emb = self.flatten(emb64_sgl_emb)
|
||||
|
||||
mult_emb_1 = self.gather_v2(self.emb64_multi, multi_doc_ad_category_id,
|
||||
0)
|
||||
mult_emb_1 = self.mul(
|
||||
self.cast(mult_emb_1, mstype.float32),
|
||||
self.cast(self.expand_dims(multi_doc_ad_category_id_mask, 2),
|
||||
mstype.float32))
|
||||
mult_emb_1 = self.reduceMean_false(mult_emb_1, 1)
|
||||
|
||||
mult_emb_2 = self.gather_v2(self.emb64_multi,
|
||||
multi_doc_event_entity_id, 0)
|
||||
mult_emb_2 = self.mul(
|
||||
self.cast(mult_emb_2, mstype.float32),
|
||||
self.cast(self.expand_dims(multi_doc_event_entity_id_mask, 2),
|
||||
mstype.float32))
|
||||
mult_emb_2 = self.reduceMean_false(mult_emb_2, 1)
|
||||
|
||||
mult_emb_3 = self.gather_v2(self.emb64_multi, multi_doc_ad_entity_id,
|
||||
0)
|
||||
mult_emb_3 = self.mul(
|
||||
self.cast(mult_emb_3, mstype.float32),
|
||||
self.cast(self.expand_dims(multi_doc_ad_entity_id_mask, 2),
|
||||
mstype.float32))
|
||||
mult_emb_3 = self.reduceMean_false(mult_emb_3, 1)
|
||||
|
||||
mult_emb_4 = self.gather_v2(self.emb64_multi, multi_doc_event_topic_id,
|
||||
0)
|
||||
mult_emb_4 = self.mul(
|
||||
self.cast(mult_emb_4, mstype.float32),
|
||||
self.cast(self.expand_dims(multi_doc_event_topic_id_mask, 2),
|
||||
mstype.float32))
|
||||
mult_emb_4 = self.reduceMean_false(mult_emb_4, 1)
|
||||
|
||||
mult_emb_5 = self.gather_v2(self.emb64_multi,
|
||||
multi_doc_event_category_id, 0)
|
||||
mult_emb_5 = self.mul(
|
||||
self.cast(mult_emb_5, mstype.float32),
|
||||
self.cast(self.expand_dims(multi_doc_event_category_id_mask, 2),
|
||||
mstype.float32))
|
||||
mult_emb_5 = self.reduceMean_false(mult_emb_5, 1)
|
||||
|
||||
mult_emb_6 = self.gather_v2(self.emb64_multi, multi_doc_ad_topic_id, 0)
|
||||
mult_emb_6 = self.mul(
|
||||
self.cast(mult_emb_6, mstype.float32),
|
||||
self.cast(self.expand_dims(multi_doc_ad_topic_id_mask, 2),
|
||||
mstype.float32))
|
||||
mult_emb_6 = self.reduceMean_false(mult_emb_6, 1)
|
||||
|
||||
mult_embedding = self.Concat((mult_emb_1, mult_emb_2, mult_emb_3,
|
||||
mult_emb_4, mult_emb_5, mult_emb_6))
|
||||
|
||||
input_embedding = self.Concat((val_hldr * 1, ind_emb, emb128_id_emb,
|
||||
emb64_sgl_emb, mult_embedding))
|
||||
deep_out = self.dense_layer_1(input_embedding)
|
||||
deep_out = self.dense_layer_2(deep_out)
|
||||
deep_out = self.dense_layer_3(deep_out)
|
||||
deep_out = self.dense_layer_4(deep_out)
|
||||
deep_out = self.dense_layer_5(deep_out)
|
||||
|
||||
deep_out = self.deep_predict(deep_out)
|
||||
|
||||
val_weight = self.mul(val_hldr,
|
||||
self.expand_dims(self.wide_continue_w, 0))
|
||||
|
||||
val_w_sum = self.reduce_sum_true(val_weight, 1)
|
||||
|
||||
ind_weight = self.gather_v2(self.wide_indicator_w, ind_hldr, 0)
|
||||
ind_w_sum = self.reduce_sum_true(ind_weight, 1)
|
||||
|
||||
emb128_id_weight = self.gather_v2(self.wide_emb128_w, emb128_id_hldr,
|
||||
0)
|
||||
emb128_w_sum = self.reduce_sum_true(emb128_id_weight, 1)
|
||||
|
||||
emb64_sgl_weight = self.gather_v2(self.wide_emb64_single_w,
|
||||
emb64_single_hldr, 0)
|
||||
emb64_w_sum = self.reduce_sum_true(emb64_sgl_weight, 1)
|
||||
|
||||
mult_weight_1 = self.gather_v2(self.wide_emb64_multi_w,
|
||||
multi_doc_ad_category_id, 0)
|
||||
mult_weight_1 = self.mul(
|
||||
self.cast(mult_weight_1, mstype.float32),
|
||||
self.cast(multi_doc_ad_category_id_mask, mstype.float32))
|
||||
mult_weight_1 = self.reduce_sum_true(mult_weight_1, 1)
|
||||
|
||||
mult_weight_2 = self.gather_v2(self.wide_emb64_multi_w,
|
||||
multi_doc_event_entity_id, 0)
|
||||
mult_weight_2 = self.mul(
|
||||
self.cast(mult_weight_2, mstype.float32),
|
||||
self.cast(multi_doc_event_entity_id_mask, mstype.float32))
|
||||
mult_weight_2 = self.reduce_sum_true(mult_weight_2, 1)
|
||||
|
||||
mult_weight_3 = self.gather_v2(self.wide_emb64_multi_w,
|
||||
multi_doc_ad_entity_id, 0)
|
||||
mult_weight_3 = self.mul(
|
||||
self.cast(mult_weight_3, mstype.float32),
|
||||
self.cast(multi_doc_ad_entity_id_mask, mstype.float32))
|
||||
mult_weight_3 = self.reduce_sum_true(mult_weight_3, 1)
|
||||
|
||||
mult_weight_4 = self.gather_v2(self.wide_emb64_multi_w,
|
||||
multi_doc_event_topic_id, 0)
|
||||
mult_weight_4 = self.mul(
|
||||
self.cast(mult_weight_4, mstype.float32),
|
||||
self.cast(multi_doc_event_topic_id_mask, mstype.float32))
|
||||
mult_weight_4 = self.reduce_sum_true(mult_weight_4, 1)
|
||||
|
||||
mult_weight_5 = self.gather_v2(self.wide_emb64_multi_w,
|
||||
multi_doc_event_category_id, 0)
|
||||
mult_weight_5 = self.mul(
|
||||
self.cast(mult_weight_5, mstype.float32),
|
||||
self.cast(multi_doc_event_category_id_mask, mstype.float32))
|
||||
mult_weight_5 = self.reduce_sum_true(mult_weight_5, 1)
|
||||
|
||||
mult_weight_6 = self.gather_v2(self.wide_emb64_multi_w,
|
||||
multi_doc_ad_topic_id, 0)
|
||||
|
||||
mult_weight_6 = self.mul(
|
||||
self.cast(mult_weight_6, mstype.float32),
|
||||
self.cast(multi_doc_ad_topic_id_mask, mstype.float32))
|
||||
mult_weight_6 = self.reduce_sum_true(mult_weight_6, 1)
|
||||
|
||||
mult_weight_sum = mult_weight_1 + mult_weight_2 + mult_weight_3 + mult_weight_4 + mult_weight_5 + mult_weight_6
|
||||
|
||||
wide_out = self.BiasAdd(
|
||||
val_w_sum + ind_w_sum + emb128_w_sum + emb64_w_sum +
|
||||
mult_weight_sum, self.wide_bias)
|
||||
|
||||
out = wide_out + deep_out
|
||||
return out, self.emb128_embedding, self.emb64_single, self.emb64_multi
|
||||
|
||||
|
||||
class NetWithLossClass(nn.Cell):
|
||||
""""
|
||||
Provide WideDeep training loss through network.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network
|
||||
config (Class): WideDeep config
|
||||
"""
|
||||
def __init__(self, network, config):
|
||||
super(NetWithLossClass, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.l2_coef = config.l2_coef
|
||||
|
||||
self.loss = P.SigmoidCrossEntropyWithLogits()
|
||||
self.square = P.Square()
|
||||
self.reduceMean_false = P.ReduceMean(keep_dims=False)
|
||||
self.reduceSum_false = P.ReduceSum(keep_dims=False)
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, label, continue_val, indicator_id, emb_128_id,
|
||||
emb_64_single_id, multi_doc_ad_category_id,
|
||||
multi_doc_ad_category_id_mask, multi_doc_event_entity_id,
|
||||
multi_doc_event_entity_id_mask, multi_doc_ad_entity_id,
|
||||
multi_doc_ad_entity_id_mask, multi_doc_event_topic_id,
|
||||
multi_doc_event_topic_id_mask, multi_doc_event_category_id,
|
||||
multi_doc_event_category_id_mask, multi_doc_ad_topic_id,
|
||||
multi_doc_ad_topic_id_mask, display_id, ad_id,
|
||||
display_ad_and_is_leak, is_leak):
|
||||
"""
|
||||
NetWithLossClass construct
|
||||
"""
|
||||
# emb128_embedding, emb64_single, emb64_multi
|
||||
predict, _, _, _ = self.network(
|
||||
continue_val, indicator_id, emb_128_id, emb_64_single_id,
|
||||
multi_doc_ad_category_id, multi_doc_ad_category_id_mask,
|
||||
multi_doc_event_entity_id, multi_doc_event_entity_id_mask,
|
||||
multi_doc_ad_entity_id, multi_doc_ad_entity_id_mask,
|
||||
multi_doc_event_topic_id, multi_doc_event_topic_id_mask,
|
||||
multi_doc_event_category_id, multi_doc_event_category_id_mask,
|
||||
multi_doc_ad_topic_id, multi_doc_ad_topic_id_mask, display_id,
|
||||
ad_id, display_ad_and_is_leak, is_leak)
|
||||
|
||||
predict = self.reshape(predict, (-1,))
|
||||
basic_loss = self.loss(predict, label)
|
||||
wide_loss = self.reduceMean_false(basic_loss)
|
||||
deep_loss = self.reduceMean_false(basic_loss)
|
||||
return wide_loss, deep_loss
|
||||
|
||||
|
||||
class IthOutputCell(nn.Cell):
|
||||
"""
|
||||
IthOutputCell
|
||||
"""
|
||||
def __init__(self, network, output_index):
|
||||
super(IthOutputCell, self).__init__()
|
||||
self.network = network
|
||||
self.output_index = output_index
|
||||
|
||||
def construct(self, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13,
|
||||
x14, x15, x16, x17, x18, x19, x20, x21):
|
||||
"""
|
||||
IthOutputCell construct
|
||||
"""
|
||||
predict = self.network(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11,
|
||||
x12, x13, x14, x15, x16, x17, x18, x19, x20,
|
||||
x21)[self.output_index]
|
||||
return predict
|
||||
|
||||
|
||||
class TrainStepWrap(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of WideDeep network training.
|
||||
|
||||
Append Adam and FTRL optimizers to the training network after that construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): the training network. Note that loss function should have been added.
|
||||
sens (Number): The adjust parameter. Default: 1000.0
|
||||
"""
|
||||
def __init__(self, network, config, sens=1000.0):
|
||||
super(TrainStepWrap, self).__init__()
|
||||
self.network = network
|
||||
self.network.set_train()
|
||||
self.trainable_params = network.trainable_params()
|
||||
weights_w = []
|
||||
weights_d = []
|
||||
for params in self.trainable_params:
|
||||
if 'wide' in params.name:
|
||||
weights_w.append(params)
|
||||
else:
|
||||
weights_d.append(params)
|
||||
|
||||
self.weights_w = ParameterTuple(weights_w)
|
||||
self.weights_d = ParameterTuple(weights_d)
|
||||
self.optimizer_w = FTRL(learning_rate=config.ftrl_lr,
|
||||
params=self.weights_w,
|
||||
l1=5e-4,
|
||||
l2=5e-4,
|
||||
initial_accum=0.1,
|
||||
loss_scale=sens)
|
||||
|
||||
#self.optimizer_d = ProximalAdagrad(self.weights_d, learning_rate=config.adam_lr,loss_scale=sens)
|
||||
self.optimizer_d = Adam(self.weights_d,
|
||||
learning_rate=config.adam_lr,
|
||||
eps=1e-6,
|
||||
loss_scale=sens)
|
||||
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
self.grad_w = C.GradOperation('grad_w',
|
||||
get_by_list=True,
|
||||
sens_param=True)
|
||||
self.grad_d = C.GradOperation('grad_d',
|
||||
get_by_list=True,
|
||||
sens_param=True)
|
||||
|
||||
self.sens = sens
|
||||
self.loss_net_w = IthOutputCell(network, output_index=0)
|
||||
self.loss_net_d = IthOutputCell(network, output_index=1)
|
||||
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer_w = None
|
||||
self.grad_reducer_d = None
|
||||
parallel_mode = _get_parallel_mode()
|
||||
if parallel_mode in (ParallelMode.DATA_PARALLEL,
|
||||
ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_mirror_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer_w = DistributedGradReducer(
|
||||
self.optimizer_w.parameters, mean, degree)
|
||||
self.grad_reducer_d = DistributedGradReducer(
|
||||
self.optimizer_d.parameters, mean, degree)
|
||||
|
||||
def construct(self, label, continue_val, indicator_id, emb_128_id,
|
||||
emb_64_single_id, multi_doc_ad_category_id,
|
||||
multi_doc_ad_category_id_mask, multi_doc_event_entity_id,
|
||||
multi_doc_event_entity_id_mask, multi_doc_ad_entity_id,
|
||||
multi_doc_ad_entity_id_mask, multi_doc_event_topic_id,
|
||||
multi_doc_event_topic_id_mask, multi_doc_event_category_id,
|
||||
multi_doc_event_category_id_mask, multi_doc_ad_topic_id,
|
||||
multi_doc_ad_topic_id_mask, display_id, ad_id,
|
||||
display_ad_and_is_leak, is_leak):
|
||||
"""
|
||||
TrainStepWrap construct
|
||||
"""
|
||||
weights_w = self.weights_w
|
||||
weights_d = self.weights_d
|
||||
loss_w, loss_d = self.network(
|
||||
label, continue_val, indicator_id, emb_128_id, emb_64_single_id,
|
||||
multi_doc_ad_category_id, multi_doc_ad_category_id_mask,
|
||||
multi_doc_event_entity_id, multi_doc_event_entity_id_mask,
|
||||
multi_doc_ad_entity_id, multi_doc_ad_entity_id_mask,
|
||||
multi_doc_event_topic_id, multi_doc_event_topic_id_mask,
|
||||
multi_doc_event_category_id, multi_doc_event_category_id_mask,
|
||||
multi_doc_ad_topic_id, multi_doc_ad_topic_id_mask, display_id,
|
||||
ad_id, display_ad_and_is_leak, is_leak)
|
||||
|
||||
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens) #
|
||||
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens) #
|
||||
grads_w = self.grad_w(self.loss_net_w, weights_w)(
|
||||
label, continue_val, indicator_id, emb_128_id, emb_64_single_id,
|
||||
multi_doc_ad_category_id, multi_doc_ad_category_id_mask,
|
||||
multi_doc_event_entity_id, multi_doc_event_entity_id_mask,
|
||||
multi_doc_ad_entity_id, multi_doc_ad_entity_id_mask,
|
||||
multi_doc_event_topic_id, multi_doc_event_topic_id_mask,
|
||||
multi_doc_event_category_id, multi_doc_event_category_id_mask,
|
||||
multi_doc_ad_topic_id, multi_doc_ad_topic_id_mask, display_id,
|
||||
ad_id, display_ad_and_is_leak, is_leak, sens_w)
|
||||
grads_d = self.grad_d(self.loss_net_d, weights_d)(
|
||||
label, continue_val, indicator_id, emb_128_id, emb_64_single_id,
|
||||
multi_doc_ad_category_id, multi_doc_ad_category_id_mask,
|
||||
multi_doc_event_entity_id, multi_doc_event_entity_id_mask,
|
||||
multi_doc_ad_entity_id, multi_doc_ad_entity_id_mask,
|
||||
multi_doc_event_topic_id, multi_doc_event_topic_id_mask,
|
||||
multi_doc_event_category_id, multi_doc_event_category_id_mask,
|
||||
multi_doc_ad_topic_id, multi_doc_ad_topic_id_mask, display_id,
|
||||
ad_id, display_ad_and_is_leak, is_leak, sens_d)
|
||||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads_w = self.grad_reducer_w(grads_w)
|
||||
grads_d = self.grad_reducer_d(grads_d)
|
||||
return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(
|
||||
loss_d, self.optimizer_d(grads_d))
|
||||
|
||||
|
||||
class PredictWithSigmoid(nn.Cell):
|
||||
"""
|
||||
PredictWithSigomid
|
||||
"""
|
||||
def __init__(self, network):
|
||||
super(PredictWithSigmoid, self).__init__()
|
||||
self.network = network
|
||||
self.sigmoid = P.Sigmoid()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, label, continue_val, indicator_id, emb_128_id,
|
||||
emb_64_single_id, multi_doc_ad_category_id,
|
||||
multi_doc_ad_category_id_mask, multi_doc_event_entity_id,
|
||||
multi_doc_event_entity_id_mask, multi_doc_ad_entity_id,
|
||||
multi_doc_ad_entity_id_mask, multi_doc_event_topic_id,
|
||||
multi_doc_event_topic_id_mask, multi_doc_event_category_id,
|
||||
multi_doc_event_category_id_mask, multi_doc_ad_topic_id,
|
||||
multi_doc_ad_topic_id_mask, display_id, ad_id,
|
||||
display_ad_and_is_leak, is_leak):
|
||||
"""
|
||||
PredictWithSigomid construct
|
||||
"""
|
||||
logits, _, _, _ = self.network(
|
||||
continue_val, indicator_id, emb_128_id, emb_64_single_id,
|
||||
multi_doc_ad_category_id, multi_doc_ad_category_id_mask,
|
||||
multi_doc_event_entity_id, multi_doc_event_entity_id_mask,
|
||||
multi_doc_ad_entity_id, multi_doc_ad_entity_id_mask,
|
||||
multi_doc_event_topic_id, multi_doc_event_topic_id_mask,
|
||||
multi_doc_event_category_id, multi_doc_event_category_id_mask,
|
||||
multi_doc_ad_topic_id, multi_doc_ad_topic_id_mask, display_id,
|
||||
ad_id, display_ad_and_is_leak, is_leak)
|
||||
logits = self.reshape(logits, (-1,))
|
||||
pred_probs = self.sigmoid(logits)
|
||||
return logits, pred_probs, label, display_id
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" training_and_evaluating """
|
||||
|
||||
import os
|
||||
import sys
|
||||
from mindspore import Model, context
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.callback import TimeMonitor
|
||||
|
||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||
from src.callbacks import LossCallBack, EvalCallBack
|
||||
from src.datasets import create_dataset, compute_emb_dim
|
||||
from src.metrics import AUCMetric
|
||||
from src.config import WideDeepConfig
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
def get_WideDeep_net(config):
|
||||
"""
|
||||
Get network of wide&deep model.
|
||||
"""
|
||||
WideDeep_net = WideDeepModel(config)
|
||||
|
||||
loss_net = NetWithLossClass(WideDeep_net, config)
|
||||
train_net = TrainStepWrap(loss_net, config)
|
||||
eval_net = PredictWithSigmoid(WideDeep_net)
|
||||
|
||||
return train_net, eval_net
|
||||
|
||||
|
||||
class ModelBuilder():
|
||||
"""
|
||||
ModelBuilder.
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_hook(self):
|
||||
pass
|
||||
|
||||
def get_train_hook(self):
|
||||
hooks = []
|
||||
callback = LossCallBack()
|
||||
hooks.append(callback)
|
||||
|
||||
if int(os.getenv('DEVICE_ID')) == 0:
|
||||
pass
|
||||
return hooks
|
||||
|
||||
def get_net(self, config):
|
||||
return get_WideDeep_net(config)
|
||||
|
||||
def train_and_eval(config):
|
||||
"""
|
||||
train_and_eval.
|
||||
"""
|
||||
data_path = config.data_path
|
||||
epochs = config.epochs
|
||||
print("epochs is {}".format(epochs))
|
||||
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||
batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset)
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||
batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset)
|
||||
|
||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||
|
||||
net_builder = ModelBuilder()
|
||||
|
||||
train_net, eval_net = net_builder.get_net(config)
|
||||
train_net.set_train()
|
||||
auc_metric = AUCMetric()
|
||||
|
||||
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
|
||||
|
||||
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
|
||||
callback = LossCallBack(config)
|
||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
|
||||
keep_checkpoint_max=10)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
||||
directory=config.ckpt_path, config=ckptconfig)
|
||||
|
||||
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback,
|
||||
callback, ckpoint_cb])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
wide_and_deep_config = WideDeepConfig()
|
||||
wide_and_deep_config.argparse_init()
|
||||
compute_emb_dim(wide_and_deep_config)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci",
|
||||
save_graphs=True)
|
||||
train_and_eval(wide_and_deep_config)
|
|
@ -0,0 +1,113 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" training_multinpu"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from mindspore import Model, context
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.callback import TimeMonitor
|
||||
from mindspore.train import ParallelMode
|
||||
from mindspore.communication.management import get_rank, get_group_size, init
|
||||
|
||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||
from src.callbacks import LossCallBack, EvalCallBack
|
||||
from src.datasets import create_dataset, compute_emb_dim
|
||||
from src.metrics import AUCMetric
|
||||
from src.config import WideDeepConfig
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
def get_WideDeep_net(config):
|
||||
"""
|
||||
get_WideDeep_net
|
||||
"""
|
||||
WideDeep_net = WideDeepModel(config)
|
||||
|
||||
loss_net = NetWithLossClass(WideDeep_net, config)
|
||||
train_net = TrainStepWrap(loss_net, config)
|
||||
eval_net = PredictWithSigmoid(WideDeep_net)
|
||||
|
||||
return train_net, eval_net
|
||||
|
||||
|
||||
class ModelBuilder():
|
||||
"""
|
||||
ModelBuilder
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_hook(self):
|
||||
pass
|
||||
|
||||
def get_train_hook(self):
|
||||
hooks = []
|
||||
callback = LossCallBack()
|
||||
hooks.append(callback)
|
||||
|
||||
if int(os.getenv('DEVICE_ID')) == 0:
|
||||
pass
|
||||
return hooks
|
||||
|
||||
def get_net(self, config):
|
||||
return get_WideDeep_net(config)
|
||||
|
||||
def train_and_eval(config):
|
||||
"""
|
||||
train_and_eval
|
||||
"""
|
||||
data_path = config.data_path
|
||||
epochs = config.epochs
|
||||
print("epochs is {}".format(epochs))
|
||||
|
||||
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
|
||||
batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset,
|
||||
rank_id=get_rank(), rank_size=get_group_size())
|
||||
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
|
||||
batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset,
|
||||
rank_id=get_rank(), rank_size=get_group_size())
|
||||
|
||||
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
|
||||
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
|
||||
|
||||
net_builder = ModelBuilder()
|
||||
|
||||
train_net, eval_net = net_builder.get_net(config)
|
||||
train_net.set_train()
|
||||
auc_metric = AUCMetric()
|
||||
|
||||
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
|
||||
|
||||
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
|
||||
callback = LossCallBack(config)
|
||||
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
|
||||
keep_checkpoint_max=10)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
||||
directory=config.ckpt_path, config=ckptconfig)
|
||||
|
||||
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback,
|
||||
callback, ckpoint_cb])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
wide_and_deep_config = WideDeepConfig()
|
||||
wide_and_deep_config.argparse_init()
|
||||
compute_emb_dim(wide_and_deep_config)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci",
|
||||
save_graphs=True)
|
||||
init()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
|
||||
device_num=get_group_size())
|
||||
train_and_eval(wide_and_deep_config)
|
Loading…
Reference in New Issue