forked from mindspore-Ecosystem/mindspore
optimize the deepfm test case network structure
parent 9b2b062642
commit 02650e2655
@@ -24,7 +24,7 @@ class DataConfig:
     data_vocab_size = 184965
     train_num_of_parts = 21
     test_num_of_parts = 3
-    batch_size = 1000
+    batch_size = 16000
     data_field_size = 39
     # dataset format, 1: mindrecord, 2: tfrecord, 3: h5
     data_format = 3
@@ -38,7 +38,7 @@ class ModelConfig:
     data_field_size = DataConfig.data_field_size
     data_vocab_size = DataConfig.data_vocab_size
     data_emb_dim = 80
-    deep_layer_args = [[400, 400, 512], "relu"]
+    deep_layer_args = [[1024, 512, 256, 128], "relu"]
     init_args = [-0.01, 0.01]
     weight_bias_init = ['normal', 'normal']
     keep_prob = 0.9
@@ -49,9 +49,9 @@ class TrainConfig:
     Define parameters of training.
     """
     batch_size = DataConfig.batch_size
-    l2_coef = 1e-6
-    learning_rate = 1e-5
-    epsilon = 1e-8
+    l2_coef = 8e-5
+    learning_rate = 5e-4
+    epsilon = 5e-8
     loss_scale = 1024.0
     train_epochs = 3
     save_checkpoint = True

@@ -212,7 +212,7 @@ def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=100
                                              np.array(y).flatten().reshape(batch_size, 39),
                                              np.array(z).flatten().reshape(batch_size, 1))),
                 input_columns=['feat_ids', 'feat_vals', 'label'],
-                columns_order=['feat_ids', 'feat_vals', 'label'],
+                column_order=['feat_ids', 'feat_vals', 'label'],
                 num_parallel_workers=8)
     ds = ds.repeat(epochs)
     return ds
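The only change in this hunk is the keyword rename from columns_order to column_order, tracking a rename of that argument in the MindSpore dataset map() API. A minimal compatibility sketch, assuming such a helper is wanted at all; the helper name and the try/except fallback are illustrative, not code from this repository:

    def map_with_column_order(ds, operations, columns, workers=8):
        # Prefer the newer keyword spelling; fall back to the old one so the
        # same call site also works on older MindSpore releases.
        try:
            return ds.map(operations=operations, input_columns=columns,
                          column_order=columns, num_parallel_workers=workers)
        except TypeError:  # older releases only accept columns_order
            return ds.map(operations=operations, input_columns=columns,
                          columns_order=columns, num_parallel_workers=workers)
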
@@ -24,9 +24,12 @@ from mindspore.ops import operations as P
 from mindspore.nn import Dropout
 from mindspore.nn.optim import Adam
 from mindspore.nn.metrics import Metric
-from mindspore import nn, ParameterTuple, Parameter
-from mindspore.common.initializer import Uniform, initializer, Normal
+from mindspore import nn, Tensor, ParameterTuple, Parameter
+from mindspore.common.initializer import Uniform, initializer
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
+from mindspore.context import ParallelMode
+from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

 from .callback import EvalCallBack, LossCallBack

@@ -60,7 +63,7 @@ class AUCMetric(Metric):
         return auc


-def init_method(method, shape, name, max_val=0.01):
+def init_method(method, shape, name, max_val=1.0):
     """
     The method of init parameters.

@@ -73,18 +76,18 @@ def init_method(method, shape, name, max_val=0.01):
     Returns:
         Parameter.
     """
-    if method in ['random', 'uniform']:
+    if method in ['uniform']:
         params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
     elif method == "one":
         params = Parameter(initializer("ones", shape, ms_type), name=name)
     elif method == 'zero':
         params = Parameter(initializer("zeros", shape, ms_type), name=name)
     elif method == "normal":
-        params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name)
+        params = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=name)
     return params


-def init_var_dict(init_args, values):
+def init_var_dict(init_args, var_list):
     """
     Init parameter.

@@ -96,17 +99,19 @@ def init_var_dict(init_args, values):
         dict, a dict ot Parameter.
     """
     var_map = {}
-    _, _max_val = init_args
-    for key, shape, init_flag in values:
+    _, max_val = init_args
+    for i, _ in enumerate(var_list):
+        key, shape, method = var_list[i]
         if key not in var_map.keys():
-            if init_flag in ['random', 'uniform']:
-                var_map[key] = Parameter(initializer(Uniform(_max_val), shape, ms_type), name=key)
-            elif init_flag == "one":
+            if method in ['random', 'uniform']:
+                var_map[key] = Parameter(initializer(Uniform(max_val), shape, ms_type), name=key)
+            elif method == "one":
                 var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key)
-            elif init_flag == "zero":
+            elif method == "zero":
                 var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key)
-            elif init_flag == 'normal':
-                var_map[key] = Parameter(initializer(Normal(_max_val), shape, ms_type), name=key)
+            elif method == 'normal':
+                var_map[key] = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).
+                                                astype(dtype=np_type)), name=key)
     return var_map

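With the two hunks above, the 'normal' branches no longer go through mindspore.common.initializer.Normal (hence the dropped import) and instead draw values directly with NumPy at a fixed standard deviation of 0.01, so max_val no longer affects normal initialization. A minimal sketch of the new behaviour, assuming np_type is the module's existing NumPy float32 alias:

    import numpy as np
    from mindspore import Tensor, Parameter

    np_type = np.float32  # assumed to match the module-level alias in the model file

    def normal_param(shape, name):
        # fixed std of 0.01, independent of init_args / max_val
        data = np.random.normal(loc=0.0, scale=0.01, size=shape).astype(np_type)
        return Parameter(Tensor(data), name=name)

    fm_w = normal_param([184965, 1], "W_l2")        # vocab_size x 1, as in init_acts below
    embedding = normal_param([184965, 80], "V_l2")  # vocab_size x emb_dim
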
@@ -122,7 +127,9 @@ class DenseLayer(nn.Cell):
         keep_prob (float): Dropout Layer keep_prob_rate;
         scale_coef (float): input scale coefficient;
     """
-    def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0):
+
+    def __init__(self, input_dim, output_dim, weight_bias_init, act_str, scale_coef=1.0, convert_dtype=True,
+                 use_act=True):
         super(DenseLayer, self).__init__()
         weight_init, bias_init = weight_bias_init
         self.weight = init_method(weight_init, [input_dim, output_dim], name="weight")
@@ -131,12 +138,15 @@ class DenseLayer(nn.Cell):
         self.matmul = P.MatMul(transpose_b=False)
         self.bias_add = P.BiasAdd()
         self.cast = P.Cast()
-        self.dropout = Dropout(keep_prob=keep_prob)
+        self.dropout = Dropout(keep_prob=1.0)
         self.mul = P.Mul()
         self.realDiv = P.RealDiv()
         self.scale_coef = scale_coef
+        self.convert_dtype = convert_dtype
+        self.use_act = use_act
+
     def _init_activation(self, act_str):
         """Init activation function"""
         act_str = act_str.lower()
         if act_str == "relu":
             act_func = P.ReLU()
@@ -147,17 +157,23 @@ class DenseLayer(nn.Cell):
         return act_func

     def construct(self, x):
-        x = self.act_func(x)
-        if self.training:
-            x = self.dropout(x)
-        x = self.mul(x, self.scale_coef)
-        x = self.cast(x, mstype.float16)
-        weight = self.cast(self.weight, mstype.float16)
-        wx = self.matmul(x, weight)
-        wx = self.cast(wx, mstype.float32)
-        wx = self.realDiv(wx, self.scale_coef)
-        output = self.bias_add(wx, self.bias)
-        return output
+        """Construct function"""
+        x = self.dropout(x)
+        if self.convert_dtype:
+            x = self.cast(x, mstype.float16)
+            weight = self.cast(self.weight, mstype.float16)
+            bias = self.cast(self.bias, mstype.float16)
+            wx = self.matmul(x, weight)
+            wx = self.bias_add(wx, bias)
+            if self.use_act:
+                wx = self.act_func(wx)
+            wx = self.cast(wx, mstype.float32)
+        else:
+            wx = self.matmul(x, self.weight)
+            wx = self.bias_add(wx, self.bias)
+            if self.use_act:
+                wx = self.act_func(wx)
+        return wx


 class DeepFMModel(nn.Cell):
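Taken together, the three DenseLayer hunks above disable dropout (keep_prob is fixed at 1.0), optionally run the matmul and bias_add in float16 when convert_dtype=True, and let the last layer skip its activation via use_act=False. A hedged usage sketch; the import path and shapes are illustrative, not taken from this repository:

    import numpy as np
    from mindspore import Tensor
    from src.deepfm import DenseLayer   # assumed module path

    hidden = DenseLayer(39 * 80, 1024, ['normal', 'normal'], "relu",
                        convert_dtype=True)                # ReLU applied after bias_add
    head = DenseLayer(128, 1, ['normal', 'normal'], "relu",
                      convert_dtype=True, use_act=False)   # raw logits for the loss

    x = Tensor(np.random.normal(size=(16000, 39 * 80)).astype(np.float32))
    h = hidden(x)   # float32 in and out; float16 is used only inside matmul/bias_add
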
@@ -176,6 +192,7 @@ class DeepFMModel(nn.Cell):
                                         (list[str], weight_bias_init=['random', 'zero'])
         keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8)
     """
+
     def __init__(self, config):
         super(DeepFMModel, self).__init__()

@@ -188,24 +205,24 @@ class DeepFMModel(nn.Cell):
         self.weight_bias_init = config.weight_bias_init
         self.keep_prob = config.keep_prob
         init_acts = [('W_l2', [self.vocab_size, 1], 'normal'),
-                     ('V_l2', [self.vocab_size, self.emb_dim], 'normal'),
-                     ('b', [1], 'normal')]
+                     ('V_l2', [self.vocab_size, self.emb_dim], 'normal')]
         var_map = init_var_dict(self.init_args, init_acts)
         self.fm_w = var_map["W_l2"]
-        self.fm_b = var_map["b"]
         self.embedding_table = var_map["V_l2"]
-        # Deep Layers
-        self.deep_input_dims = self.field_size * self.emb_dim + 1
+        " Deep Layers "
+        self.deep_input_dims = self.field_size * self.emb_dim
         self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
-        self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        # FM, linear Layers
+        self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True, use_act=False)
+        " FM, linear Layers "
         self.Gatherv2 = P.GatherV2()
         self.Mul = P.Mul()
         self.ReduceSum = P.ReduceSum(keep_dims=False)
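With the new deep_layer_args and the dropped "+ 1" bias column, the deep part is now a five-stage MLP. A quick hedged check of the resulting shapes, using only numbers from the config hunks above:

    field_size, emb_dim = 39, 80                    # DataConfig / ModelConfig
    deep_layer_dims_list = [1024, 512, 256, 128]    # new deep_layer_args
    deep_input_dims = field_size * emb_dim          # 3120; the extra bias column is gone
    all_dim_list = [deep_input_dims] + deep_layer_dims_list + [1]
    assert all_dim_list == [3120, 1024, 512, 256, 128, 1]   # dense_layer_1 ... dense_layer_5
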
@@ -238,16 +255,14 @@ class DeepFMModel(nn.Cell):
         fm_out = 0.5 * self.ReduceSum(v1 - v2, 1)
         fm_out = self.Reshape(fm_out, (-1, 1))
         # Deep layer
-        b = self.Reshape(self.fm_b, (1, 1))
-        b = self.Tile(b, (self.batch_size, 1))
         deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim))
-        deep_in = self.Concat((deep_in, b))
         deep_in = self.dense_layer_1(deep_in)
         deep_in = self.dense_layer_2(deep_in)
         deep_in = self.dense_layer_3(deep_in)
-        deep_out = self.dense_layer_4(deep_in)
+        deep_in = self.dense_layer_4(deep_in)
+        deep_out = self.dense_layer_5(deep_in)
         out = linear_out + fm_out + deep_out
-        return out, fm_id_weight, fm_id_embs
+        return out, self.fm_w, self.embedding_table


 class NetWithLossClass(nn.Cell):
@@ -278,7 +293,7 @@ class TrainStepWrap(nn.Cell):
     """
     TrainStepWrap definition
     """
-    def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0):
+    def __init__(self, network, lr, eps, loss_scale=1000.0):
         super(TrainStepWrap, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_train()
@@ -288,11 +303,24 @@ class TrainStepWrap(nn.Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = loss_scale

+        self.reducer_flag = False
+        self.grad_reducer = None
+        parallel_mode = _get_parallel_mode()
+        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
+            self.reducer_flag = True
+        if self.reducer_flag:
+            mean = _get_gradients_mean()
+            degree = _get_device_num()
+            self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
+
     def construct(self, batch_ids, batch_wts, label):
         weights = self.weights
         loss = self.network(batch_ids, batch_wts, label)
-        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
+        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) #
         grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
+        if self.reducer_flag:
+            # apply grad reducer on grads
+            grads = self.grad_reducer(grads)
         return F.depend(loss, self.optimizer(grads))

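A hedged wiring sketch showing how the reworked wrapper is typically assembled from the config values changed at the top of this commit; the module paths and the exact builder flow are assumptions, not taken verbatim from this repository:

    from src.config import ModelConfig, TrainConfig                     # assumed paths
    from src.deepfm import DeepFMModel, NetWithLossClass, TrainStepWrap

    model_config, train_config = ModelConfig(), TrainConfig()
    net = DeepFMModel(model_config)
    loss_net = NetWithLossClass(net, l2_coef=train_config.l2_coef)      # 8e-5
    train_net = TrainStepWrap(loss_net,
                              lr=train_config.learning_rate,            # 5e-4
                              eps=train_config.epsilon,                 # 5e-8
                              loss_scale=train_config.loss_scale)       # 1024.0
    train_net.set_train()
    # Under DATA_PARALLEL / HYBRID_PARALLEL the wrapper now attaches a
    # DistributedGradReducer by itself, so no extra call is needed here.
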
@@ -74,7 +74,7 @@ def test_deepfm():
     export_loss_value = 0.51
     print("loss_callback.loss:", loss_callback.loss)
     assert loss_callback.loss < export_loss_value
-    export_per_step_time = 10.4
+    export_per_step_time = 40.0
     print("time_callback:", time_callback.per_step_time)
     assert time_callback.per_step_time < export_per_step_time
     print("*******test case pass!********")