!7055 update deepfm network.

Merge pull request !7055 from linqingke/deepfm
This commit is contained in:
mindspore-ci-bot 2020-10-15 16:07:43 +08:00 committed by Gitee
commit 9ccf5cb2b5
2 changed files with 422 additions and 434 deletions

View File

@ -1,62 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
class DataConfig:
"""
Define parameters of dataset.
"""
data_vocab_size = 184965
train_num_of_parts = 21
test_num_of_parts = 3
batch_size = 1000
data_field_size = 39
# dataset format, 1: mindrecord, 2: tfrecord, 3: h5
data_format = 2
class ModelConfig:
"""
Define parameters of model.
"""
batch_size = DataConfig.batch_size
data_field_size = DataConfig.data_field_size
data_vocab_size = DataConfig.data_vocab_size
data_emb_dim = 80
deep_layer_args = [[400, 400, 512], "relu"]
init_args = [-0.01, 0.01]
weight_bias_init = ['normal', 'normal']
keep_prob = 0.9
class TrainConfig:
"""
Define parameters of training.
"""
batch_size = DataConfig.batch_size
l2_coef = 1e-6
learning_rate = 1e-5
epsilon = 1e-8
loss_scale = 1024.0
train_epochs = 15
save_checkpoint = True
ckpt_file_name_prefix = "deepfm"
save_checkpoint_steps = 1
keep_checkpoint_max = 15
eval_callback = True
loss_callback = True
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
class DataConfig:
"""data config"""
data_vocab_size = 184965
train_num_of_parts = 21
test_num_of_parts = 3
batch_size = 16000
data_field_size = 39
data_format = 2
class ModelConfig:
"""model config"""
batch_size = DataConfig.batch_size
data_field_size = DataConfig.data_field_size
data_vocab_size = DataConfig.data_vocab_size
data_emb_dim = 80
deep_layer_args = [[1024, 512, 256, 128], "relu"]
init_args = [-0.01, 0.01]
weight_bias_init = ['normal', 'normal']
keep_prob = 0.9
class TrainConfig:
"""train config"""
batch_size = DataConfig.batch_size
l2_coef = 8e-5
learning_rate = 5e-4
epsilon = 5e-8
loss_scale = 1024.0
train_epochs = 5
save_checkpoint = True
ckpt_file_name_prefix = "deepfm"
save_checkpoint_steps = 1
keep_checkpoint_max = 50
eval_callback = True
loss_callback = True

View File

@ -1,372 +1,367 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_training """
import os
import numpy as np
from sklearn.metrics import roc_auc_score
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn import Dropout
from mindspore.nn.optim import Adam
from mindspore.nn.metrics import Metric
from mindspore import nn, ParameterTuple, Parameter
from mindspore.common.initializer import Uniform, initializer, Normal
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from .callback import EvalCallBack, LossCallBack
np_type = np.float32
ms_type = mstype.float32
class AUCMetric(Metric):
"""AUC metric for DeepFM model."""
def __init__(self):
super(AUCMetric, self).__init__()
self.pred_probs = []
self.true_labels = []
def clear(self):
"""Clear the internal evaluation result."""
self.pred_probs = []
self.true_labels = []
def update(self, *inputs):
batch_predict = inputs[1].asnumpy()
batch_label = inputs[2].asnumpy()
self.pred_probs.extend(batch_predict.flatten().tolist())
self.true_labels.extend(batch_label.flatten().tolist())
def eval(self):
if len(self.true_labels) != len(self.pred_probs):
raise RuntimeError('true_labels.size() is not equal to pred_probs.size()')
auc = roc_auc_score(self.true_labels, self.pred_probs)
return auc
def init_method(method, shape, name, max_val=0.01):
"""
The method of init parameters.
Args:
method (str): The method uses to initialize parameter.
shape (list): The shape of parameter.
name (str): The name of parameter.
max_val (float): Max value in parameter when uses 'random' or 'uniform' to initialize parameter.
Returns:
Parameter.
"""
if method in ['random', 'uniform']:
params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
elif method == "one":
params = Parameter(initializer("ones", shape, ms_type), name=name)
elif method == 'zero':
params = Parameter(initializer("zeros", shape, ms_type), name=name)
elif method == "normal":
params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name)
return params
def init_var_dict(init_args, values):
"""
Init parameter.
Args:
init_args (list): Define max and min value of parameters.
values (list): Define name, shape and init method of parameters.
Returns:
dict, a dict ot Parameter.
"""
var_map = {}
_, _max_val = init_args
for key, shape, init_flag in values:
if key not in var_map.keys():
if init_flag in ['random', 'uniform']:
var_map[key] = Parameter(initializer(Uniform(_max_val), shape, ms_type), name=key)
elif init_flag == "one":
var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key)
elif init_flag == "zero":
var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key)
elif init_flag == 'normal':
var_map[key] = Parameter(initializer(Normal(_max_val), shape, ms_type), name=key)
return var_map
class DenseLayer(nn.Cell):
"""
Dense Layer for Deep Layer of DeepFM Model;
Containing: activation, matmul, bias_add;
Args:
input_dim (int): the shape of weight at 0-aixs;
output_dim (int): the shape of weight at 1-aixs, and shape of bias
weight_bias_init (list): weight and bias init method, "random", "uniform", "one", "zero", "normal";
act_str (str): activation function method, "relu", "sigmoid", "tanh";
keep_prob (float): Dropout Layer keep_prob_rate;
scale_coef (float): input scale coefficient;
"""
def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0):
super(DenseLayer, self).__init__()
weight_init, bias_init = weight_bias_init
self.weight = init_method(weight_init, [input_dim, output_dim], name="weight")
self.bias = init_method(bias_init, [output_dim], name="bias")
self.act_func = self._init_activation(act_str)
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
self.dropout = Dropout(keep_prob=keep_prob)
self.mul = P.Mul()
self.realDiv = P.RealDiv()
self.scale_coef = scale_coef
def _init_activation(self, act_str):
act_str = act_str.lower()
if act_str == "relu":
act_func = P.ReLU()
elif act_str == "sigmoid":
act_func = P.Sigmoid()
elif act_str == "tanh":
act_func = P.Tanh()
return act_func
def construct(self, x):
"""Dense Layer for Deep Layer of DeepFM Model."""
x = self.act_func(x)
if self.training:
x = self.dropout(x)
x = self.mul(x, self.scale_coef)
x = self.cast(x, mstype.float16)
weight = self.cast(self.weight, mstype.float16)
wx = self.matmul(x, weight)
wx = self.cast(wx, mstype.float32)
wx = self.realDiv(wx, self.scale_coef)
output = self.bias_add(wx, self.bias)
return output
class DeepFMModel(nn.Cell):
"""
From paper: "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction"
Args:
batch_size (int): smaple_number of per step in training; (int, batch_size=128)
filed_size (int): input filed number, or called id_feature number; (int, filed_size=39)
vocab_size (int): id_feature vocab size, id dict size; (int, vocab_size=200000)
emb_dim (int): id embedding vector dim, id mapped to embedding vector; (int, emb_dim=100)
deep_layer_args (list): Deep Layer args, layer_dim_list, layer_activator;
(int, deep_layer_args=[[100, 100, 100], "relu"])
init_args (list): init args for Parameter init; (list, init_args=[min, max, seeds])
weight_bias_init (list): weight, bias init method for deep layers;
(list[str], weight_bias_init=['random', 'zero'])
keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8)
"""
def __init__(self, config):
super(DeepFMModel, self).__init__()
self.batch_size = config.batch_size
self.field_size = config.data_field_size
self.vocab_size = config.data_vocab_size
self.emb_dim = config.data_emb_dim
self.deep_layer_dims_list, self.deep_layer_act = config.deep_layer_args
self.init_args = config.init_args
self.weight_bias_init = config.weight_bias_init
self.keep_prob = config.keep_prob
init_acts = [('W_l2', [self.vocab_size, 1], 'normal'),
('V_l2', [self.vocab_size, self.emb_dim], 'normal'),
('b', [1], 'normal')]
var_map = init_var_dict(self.init_args, init_acts)
self.fm_w = var_map["W_l2"]
self.fm_b = var_map["b"]
self.embedding_table = var_map["V_l2"]
# Deep Layers
self.deep_input_dims = self.field_size * self.emb_dim + 1
self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4],
self.weight_bias_init, self.deep_layer_act, self.keep_prob)
# FM, linear Layers
self.Gatherv2 = P.GatherV2()
self.Mul = P.Mul()
self.ReduceSum = P.ReduceSum(keep_dims=False)
self.Reshape = P.Reshape()
self.Square = P.Square()
self.Shape = P.Shape()
self.Tile = P.Tile()
self.Concat = P.Concat(axis=1)
self.Cast = P.Cast()
def construct(self, id_hldr, wt_hldr):
"""
Args:
id_hldr: batch ids; [bs, field_size]
wt_hldr: batch weights; [bs, field_size]
"""
mask = self.Reshape(wt_hldr, (self.batch_size, self.field_size, 1))
# Linear layer
fm_id_weight = self.Gatherv2(self.fm_w, id_hldr, 0)
wx = self.Mul(fm_id_weight, mask)
linear_out = self.ReduceSum(wx, 1)
# FM layer
fm_id_embs = self.Gatherv2(self.embedding_table, id_hldr, 0)
vx = self.Mul(fm_id_embs, mask)
v1 = self.ReduceSum(vx, 1)
v1 = self.Square(v1)
v2 = self.Square(vx)
v2 = self.ReduceSum(v2, 1)
fm_out = 0.5 * self.ReduceSum(v1 - v2, 1)
fm_out = self.Reshape(fm_out, (-1, 1))
# Deep layer
b = self.Reshape(self.fm_b, (1, 1))
b = self.Tile(b, (self.batch_size, 1))
deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim))
deep_in = self.Concat((deep_in, b))
deep_in = self.dense_layer_1(deep_in)
deep_in = self.dense_layer_2(deep_in)
deep_in = self.dense_layer_3(deep_in)
deep_out = self.dense_layer_4(deep_in)
out = linear_out + fm_out + deep_out
return out, fm_id_weight, fm_id_embs
class NetWithLossClass(nn.Cell):
"""
NetWithLossClass definition.
"""
def __init__(self, network, l2_coef=1e-6):
super(NetWithLossClass, self).__init__(auto_prefix=False)
self.loss = P.SigmoidCrossEntropyWithLogits()
self.network = network
self.l2_coef = l2_coef
self.Square = P.Square()
self.ReduceMean_false = P.ReduceMean(keep_dims=False)
self.ReduceSum_false = P.ReduceSum(keep_dims=False)
def construct(self, batch_ids, batch_wts, label):
predict, fm_id_weight, fm_id_embs = self.network(batch_ids, batch_wts)
log_loss = self.loss(predict, label)
mean_log_loss = self.ReduceMean_false(log_loss)
l2_loss_w = self.ReduceSum_false(self.Square(fm_id_weight))
l2_loss_v = self.ReduceSum_false(self.Square(fm_id_embs))
l2_loss_all = self.l2_coef * (l2_loss_v + l2_loss_w) * 0.5
loss = mean_log_loss + l2_loss_all
return loss
class TrainStepWrap(nn.Cell):
"""
TrainStepWrap definition
"""
def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0):
super(TrainStepWrap, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.set_train()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = loss_scale
def construct(self, batch_ids, batch_wts, label):
weights = self.weights
loss = self.network(batch_ids, batch_wts, label)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
return F.depend(loss, self.optimizer(grads))
class PredictWithSigmoid(nn.Cell):
"""
Eval model with sigmoid.
"""
def __init__(self, network):
super(PredictWithSigmoid, self).__init__(auto_prefix=False)
self.network = network
self.sigmoid = P.Sigmoid()
def construct(self, batch_ids, batch_wts, labels):
logits, _, _, = self.network(batch_ids, batch_wts)
pred_probs = self.sigmoid(logits)
return logits, pred_probs, labels
class ModelBuilder:
"""
Model builder for DeepFM.
Args:
model_config (ModelConfig): Model configuration.
train_config (TrainConfig): Train configuration.
"""
def __init__(self, model_config, train_config):
self.model_config = model_config
self.train_config = train_config
def get_callback_list(self, model=None, eval_dataset=None):
"""
Get callbacks which contains checkpoint callback, eval callback and loss callback.
Args:
model (Cell): The network is added callback (default=None).
eval_dataset (Dataset): Dataset for eval (default=None).
"""
callback_list = []
if self.train_config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=self.train_config.save_checkpoint_steps,
keep_checkpoint_max=self.train_config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=self.train_config.ckpt_file_name_prefix,
directory=self.train_config.output_path,
config=config_ck)
callback_list.append(ckpt_cb)
if self.train_config.eval_callback:
if model is None:
raise RuntimeError("train_config.eval_callback is {}; get_callback_list() args model is {}".format(
self.train_config.eval_callback, model))
if eval_dataset is None:
raise RuntimeError("train_config.eval_callback is {}; get_callback_list() "
"args eval_dataset is {}".format(self.train_config.eval_callback, eval_dataset))
auc_metric = AUCMetric()
eval_callback = EvalCallBack(model, eval_dataset, auc_metric,
eval_file_path=os.path.join(self.train_config.output_path,
self.train_config.eval_file_name))
callback_list.append(eval_callback)
if self.train_config.loss_callback:
loss_callback = LossCallBack(loss_file_path=os.path.join(self.train_config.output_path,
self.train_config.loss_file_name))
callback_list.append(loss_callback)
if callback_list:
return callback_list
return None
def get_train_eval_net(self):
deepfm_net = DeepFMModel(self.model_config)
loss_net = NetWithLossClass(deepfm_net, l2_coef=self.train_config.l2_coef)
train_net = TrainStepWrap(loss_net, lr=self.train_config.learning_rate,
eps=self.train_config.epsilon,
loss_scale=self.train_config.loss_scale)
eval_net = PredictWithSigmoid(deepfm_net)
return train_net, eval_net
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_training """
import os
import numpy as np
from sklearn.metrics import roc_auc_score
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn import Dropout
from mindspore.nn.optim import Adam
from mindspore.nn.metrics import Metric
from mindspore import nn, Tensor, ParameterTuple, Parameter
from mindspore.common.initializer import Uniform, initializer
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from src.callback import EvalCallBack, LossCallBack
np_type = np.float32
ms_type = mstype.float32
class AUCMetric(Metric):
"""
Metric method
"""
def __init__(self):
super(AUCMetric, self).__init__()
self.pred_probs = []
self.true_labels = []
def clear(self):
"""Clear the internal evaluation result."""
self.pred_probs = []
self.true_labels = []
def update(self, *inputs):
batch_predict = inputs[1].asnumpy()
batch_label = inputs[2].asnumpy()
self.pred_probs.extend(batch_predict.flatten().tolist())
self.true_labels.extend(batch_label.flatten().tolist())
def eval(self):
if len(self.true_labels) != len(self.pred_probs):
raise RuntimeError('true_labels.size() is not equal to pred_probs.size()')
auc = roc_auc_score(self.true_labels, self.pred_probs)
return auc
def init_method(method, shape, name, max_val=1.0):
if method in ['uniform']:
params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
elif method == "one":
params = Parameter(initializer("ones", shape, ms_type), name=name)
elif method == 'zero':
params = Parameter(initializer("zeros", shape, ms_type), name=name)
elif method == "normal":
params = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=name)
return params
def init_var_dict(init_args, var_list):
""" init var with different methods. """
var_map = {}
_, max_val = init_args
for i, _ in enumerate(var_list):
key, shape, method = var_list[i]
if key not in var_map.keys():
if method in ['random', 'uniform']:
var_map[key] = Parameter(initializer(Uniform(max_val), shape, ms_type), name=key)
elif method == "one":
var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key)
elif method == "zero":
var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key)
elif method == 'normal':
var_map[key] = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).
astype(dtype=np_type)), name=key)
return var_map
#
class DenseLayer(nn.Cell):
"""
Dense Layer for Deep Layer of WideDeep Model;
Containing: activation, matmul, bias_add;
Args:
"""
def __init__(self, input_dim, output_dim, weight_bias_init, act_str, scale_coef=1.0, convert_dtype=True,
use_act=True):
super(DenseLayer, self).__init__()
weight_init, bias_init = weight_bias_init
self.weight = init_method(weight_init, [input_dim, output_dim], name="weight")
self.bias = init_method(bias_init, [output_dim], name="bias")
self.act_func = self._init_activation(act_str)
self.matmul = P.MatMul(transpose_b=False)
self.bias_add = P.BiasAdd()
self.cast = P.Cast()
self.dropout = Dropout(keep_prob=1.0)
self.mul = P.Mul()
self.realDiv = P.RealDiv()
self.scale_coef = scale_coef
self.convert_dtype = convert_dtype
self.use_act = use_act
def _init_activation(self, act_str):
"""Init activation function"""
act_str = act_str.lower()
if act_str == "relu":
act_func = P.ReLU()
elif act_str == "sigmoid":
act_func = P.Sigmoid()
elif act_str == "tanh":
act_func = P.Tanh()
return act_func
def construct(self, x):
"""Construct function"""
x = self.dropout(x)
if self.convert_dtype:
x = self.cast(x, mstype.float16)
weight = self.cast(self.weight, mstype.float16)
bias = self.cast(self.bias, mstype.float16)
wx = self.matmul(x, weight)
wx = self.bias_add(wx, bias)
if self.use_act:
wx = self.act_func(wx)
wx = self.cast(wx, mstype.float32)
else:
wx = self.matmul(x, self.weight)
wx = self.bias_add(wx, self.bias)
if self.use_act:
wx = self.act_func(wx)
return wx
class DeepFMModel(nn.Cell):
"""
From paper: "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction"
Args:
batch_size(int): smaple_number of per step in training; (int, batch_size=128)
filed_size(int): input filed number, or called id_feature number; (int, filed_size=39)
vocab_size(int): id_feature vocab size, id dict size; (int, vocab_size=200000)
emb_dim(int): id embedding vector dim, id mapped to embedding vector; (int, emb_dim=100)
deep_layer_args(list): Deep Layer args, layer_dim_list, layer_activator;
(int, deep_layer_args=[[100, 100, 100], "relu"])
init_args(list): init args for Parameter init; (list, init_args=[min, max, seeds])
weight_bias_init(list): weight, bias init method for deep layers;
(list[str], weight_bias_init=['random', 'zero'])
keep_prob(float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8)
"""
def __init__(self, config):
super(DeepFMModel, self).__init__()
self.batch_size = config.batch_size
self.field_size = config.data_field_size
self.vocab_size = config.data_vocab_size
self.emb_dim = config.data_emb_dim
self.deep_layer_dims_list, self.deep_layer_act = config.deep_layer_args
self.init_args = config.init_args
self.weight_bias_init = config.weight_bias_init
self.keep_prob = config.keep_prob
init_acts = [('W_l2', [self.vocab_size, 1], 'normal'),
('V_l2', [self.vocab_size, self.emb_dim], 'normal')]
var_map = init_var_dict(self.init_args, init_acts)
self.fm_w = var_map["W_l2"]
self.embedding_table = var_map["V_l2"]
" Deep Layers "
self.deep_input_dims = self.field_size * self.emb_dim
self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1], self.weight_bias_init,
self.deep_layer_act, self.keep_prob, convert_dtype=True)
self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2], self.weight_bias_init,
self.deep_layer_act, self.keep_prob, convert_dtype=True)
self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3], self.weight_bias_init,
self.deep_layer_act, self.keep_prob, convert_dtype=True)
self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4], self.weight_bias_init,
self.deep_layer_act, self.keep_prob, convert_dtype=True)
self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5], self.weight_bias_init,
self.deep_layer_act, self.keep_prob, convert_dtype=True, use_act=False)
" FM, linear Layers "
self.Gatherv2 = P.GatherV2()
self.Mul = P.Mul()
self.ReduceSum = P.ReduceSum(keep_dims=False)
self.Reshape = P.Reshape()
self.Square = P.Square()
self.Shape = P.Shape()
self.Tile = P.Tile()
self.Concat = P.Concat(axis=1)
self.Cast = P.Cast()
def construct(self, id_hldr, wt_hldr):
"""
Args:
id_hldr: batch ids; [bs, field_size]
wt_hldr: batch weights; [bs, field_size]
"""
mask = self.Reshape(wt_hldr, (self.batch_size, self.field_size, 1))
# Linear layer
fm_id_weight = self.Gatherv2(self.fm_w, id_hldr, 0)
wx = self.Mul(fm_id_weight, mask)
linear_out = self.ReduceSum(wx, 1)
# FM layer
fm_id_embs = self.Gatherv2(self.embedding_table, id_hldr, 0)
vx = self.Mul(fm_id_embs, mask)
v1 = self.ReduceSum(vx, 1)
v1 = self.Square(v1)
v2 = self.Square(vx)
v2 = self.ReduceSum(v2, 1)
fm_out = 0.5 * self.ReduceSum(v1 - v2, 1)
fm_out = self.Reshape(fm_out, (-1, 1))
# Deep layer
deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim))
deep_in = self.dense_layer_1(deep_in)
deep_in = self.dense_layer_2(deep_in)
deep_in = self.dense_layer_3(deep_in)
deep_in = self.dense_layer_4(deep_in)
deep_out = self.dense_layer_5(deep_in)
out = linear_out + fm_out + deep_out
return out, self.fm_w, self.embedding_table
class NetWithLossClass(nn.Cell):
"""
NetWithLossClass definition
"""
def __init__(self, network, l2_coef=1e-6):
super(NetWithLossClass, self).__init__(auto_prefix=False)
self.loss = P.SigmoidCrossEntropyWithLogits()
self.network = network
self.l2_coef = l2_coef
self.Square = P.Square()
self.ReduceMean_false = P.ReduceMean(keep_dims=False)
self.ReduceSum_false = P.ReduceSum(keep_dims=False)
#
def construct(self, batch_ids, batch_wts, label):
predict, fm_id_weight, fm_id_embs = self.network(batch_ids, batch_wts)
log_loss = self.loss(predict, label)
mean_log_loss = self.ReduceMean_false(log_loss)
l2_loss_w = self.ReduceSum_false(self.Square(fm_id_weight))
l2_loss_v = self.ReduceSum_false(self.Square(fm_id_embs))
l2_loss_all = self.l2_coef * (l2_loss_v + l2_loss_w) * 0.5
loss = mean_log_loss + l2_loss_all
return loss
class TrainStepWrap(nn.Cell):
"""
TrainStepWrap definition
"""
def __init__(self, network, lr, eps, loss_scale=1000.0):
super(TrainStepWrap, self).__init__(auto_prefix=False)
self.network = network
self.network.set_train()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale)
self.hyper_map = C.HyperMap()
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = loss_scale
self.reducer_flag = False
self.grad_reducer = None
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
def construct(self, batch_ids, batch_wts, label):
weights = self.weights
loss = self.network(batch_ids, batch_wts, label)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
class PredictWithSigmoid(nn.Cell):
"""
Predict method
"""
def __init__(self, network):
super(PredictWithSigmoid, self).__init__(auto_prefix=False)
self.network = network
self.sigmoid = P.Sigmoid()
def construct(self, batch_ids, batch_wts, labels):
logits, _, _, = self.network(batch_ids, batch_wts)
pred_probs = self.sigmoid(logits)
return logits, pred_probs, labels
class ModelBuilder():
"""
Model builder for DeepFM.
Args:
model_config (ModelConfig): Model configuration.
train_config (TrainConfig): Train configuration.
"""
def __init__(self, model_config, train_config):
self.model_config = model_config
self.train_config = train_config
def get_callback_list(self, model=None, eval_dataset=None):
"""
Get callbacks which contains checkpoint callback, eval callback and loss callback.
Args:
model (Cell): The network is added callback (default=None)
eval_dataset (Dataset): Dataset for eval (default=None)
"""
callback_list = []
if self.train_config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=self.train_config.save_checkpoint_steps,
keep_checkpoint_max=self.train_config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=self.train_config.ckpt_file_name_prefix,
directory=self.train_config.output_path,
config=config_ck)
callback_list.append(ckpt_cb)
if self.train_config.eval_callback:
if model is None:
raise RuntimeError("train_config.eval_callback is {}; get_callback_list() args model is {}".format(
self.train_config.eval_callback, model))
if eval_dataset is None:
raise RuntimeError("train_config.eval_callback is {}; get_callback_list() args eval_dataset is {}".
format(self.train_config.eval_callback, eval_dataset))
auc_metric = AUCMetric()
eval_callback = EvalCallBack(model, eval_dataset, auc_metric,
eval_file_path=os.path.join(self.train_config.output_path,
self.train_config.eval_file_name))
callback_list.append(eval_callback)
if self.train_config.loss_callback:
loss_callback = LossCallBack(loss_file_path=os.path.join(self.train_config.output_path,
self.train_config.loss_file_name))
callback_list.append(loss_callback)
if callback_list:
return callback_list
return None
def get_train_eval_net(self):
deepfm_net = DeepFMModel(self.model_config)
loss_net = NetWithLossClass(deepfm_net, l2_coef=self.train_config.l2_coef)
train_net = TrainStepWrap(loss_net, lr=self.train_config.learning_rate,
eps=self.train_config.epsilon, loss_scale=self.train_config.loss_scale)
eval_net = PredictWithSigmoid(deepfm_net)
return train_net, eval_net