Adjust the dense structure in the wide&deep multi-table
This commit is contained in:
parent
3259dafa7e
commit
e0a34dae3b
|
@ -32,6 +32,7 @@ def argparse_init():
|
|||
parser.add_argument("--ftrl_lr", type=float, default=0.1) # The ftrl lr.
|
||||
parser.add_argument("--l2_coef", type=float, default=0.0) # The l2 coefficient.
|
||||
parser.add_argument("--is_tf_dataset", type=bool, default=True) # The l2 coefficient.
|
||||
parser.add_argument("--dropout_flag", type=int, default=1) # The dropout rate
|
||||
|
||||
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
|
||||
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") # The location of the checkpoints file.
|
||||
|
@ -83,7 +84,6 @@ class WideDeepConfig():
|
|||
self.weight_bias_init = ['normal', 'normal']
|
||||
self.emb_init = 'normal'
|
||||
self.init_args = [-0.01, 0.01]
|
||||
self.dropout_flag = False
|
||||
self.l2_coef = args.l2_coef
|
||||
self.ftrl_lr = args.ftrl_lr
|
||||
self.adam_lr = args.adam_lr
|
||||
|
@ -93,3 +93,4 @@ class WideDeepConfig():
|
|||
self.eval_file_name = args.eval_file_name
|
||||
self.loss_file_name = args.loss_file_name
|
||||
self.ckpt_path = args.ckpt_path
|
||||
self.dropout_flag = bool(args.dropout_flag)
|
||||
|
|
|
@ -89,9 +89,11 @@ class DenseLayer(nn.Cell):
|
|||
output_dim,
|
||||
weight_bias_init,
|
||||
act_str,
|
||||
keep_prob=0.7,
|
||||
keep_prob=0.8,
|
||||
scale_coef=1.0,
|
||||
convert_dtype=True):
|
||||
use_activation=True,
|
||||
convert_dtype=True,
|
||||
drop_out=False):
|
||||
super(DenseLayer, self).__init__()
|
||||
weight_init, bias_init = weight_bias_init
|
||||
self.weight = init_method(weight_init, [input_dim, output_dim],
|
||||
|
@ -101,11 +103,13 @@ class DenseLayer(nn.Cell):
|
|||
self.matmul = P.MatMul(transpose_b=False)
|
||||
self.bias_add = P.BiasAdd()
|
||||
self.cast = P.Cast()
|
||||
self.dropout = Dropout(keep_prob=0.8)
|
||||
self.dropout = Dropout(keep_prob=keep_prob)
|
||||
self.mul = P.Mul()
|
||||
self.realDiv = P.RealDiv()
|
||||
self.scale_coef = scale_coef
|
||||
self.use_activation = use_activation
|
||||
self.convert_dtype = convert_dtype
|
||||
self.drop_out = drop_out
|
||||
|
||||
def _init_activation(self, act_str):
|
||||
act_str = act_str.lower()
|
||||
|
@ -118,23 +122,26 @@ class DenseLayer(nn.Cell):
|
|||
return act_func
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
DenseLayer construct
|
||||
"""
|
||||
x = self.act_func(x)
|
||||
if self.training:
|
||||
'''
|
||||
Construct Dense layer
|
||||
'''
|
||||
if self.training and self.drop_out:
|
||||
x = self.dropout(x)
|
||||
x = self.mul(x, self.scale_coef)
|
||||
if self.convert_dtype:
|
||||
x = self.cast(x, mstype.float16)
|
||||
weight = self.cast(self.weight, mstype.float16)
|
||||
bias = self.cast(self.bias, mstype.float16)
|
||||
wx = self.matmul(x, weight)
|
||||
wx = self.bias_add(wx, bias)
|
||||
if self.use_activation:
|
||||
wx = self.act_func(wx)
|
||||
wx = self.cast(wx, mstype.float32)
|
||||
else:
|
||||
wx = self.matmul(x, self.weight)
|
||||
wx = self.realDiv(wx, self.scale_coef)
|
||||
output = self.bias_add(wx, self.bias)
|
||||
return output
|
||||
wx = self.bias_add(wx, self.bias)
|
||||
if self.use_activation:
|
||||
wx = self.act_func(wx)
|
||||
return wx
|
||||
|
||||
|
||||
class WideDeepModel(nn.Cell):
|
||||
|
@ -211,33 +218,40 @@ class WideDeepModel(nn.Cell):
|
|||
self.all_dim_list[1],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
drop_out=config.dropout_flag,
|
||||
convert_dtype=True)
|
||||
self.dense_layer_2 = DenseLayer(self.all_dim_list[1],
|
||||
self.all_dim_list[2],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
drop_out=config.dropout_flag,
|
||||
convert_dtype=True)
|
||||
self.dense_layer_3 = DenseLayer(self.all_dim_list[2],
|
||||
self.all_dim_list[3],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
drop_out=config.dropout_flag,
|
||||
convert_dtype=True)
|
||||
self.dense_layer_4 = DenseLayer(self.all_dim_list[3],
|
||||
self.all_dim_list[4],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
drop_out=config.dropout_flag,
|
||||
convert_dtype=True)
|
||||
self.dense_layer_5 = DenseLayer(self.all_dim_list[4],
|
||||
self.all_dim_list[5],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
drop_out=config.dropout_flag,
|
||||
convert_dtype=True)
|
||||
|
||||
self.deep_predict = DenseLayer(self.all_dim_list[5],
|
||||
self.all_dim_list[6],
|
||||
self.weight_bias_init,
|
||||
self.deep_layer_act,
|
||||
convert_dtype=True)
|
||||
drop_out=config.dropout_flag,
|
||||
convert_dtype=True,
|
||||
use_activation=False)
|
||||
|
||||
self.gather_v2 = P.GatherV2()
|
||||
self.mul = P.Mul()
|
||||
|
|
|
@ -96,9 +96,10 @@ def train_and_eval(config):
|
|||
keep_checkpoint_max=10)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
||||
directory=config.ckpt_path, config=ckptconfig)
|
||||
|
||||
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback,
|
||||
callback, ckpoint_cb])
|
||||
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
|
||||
if int(get_rank()) == 0:
|
||||
callback_list.append(ckpoint_cb)
|
||||
model.train(epochs, ds_train, callbacks=callback_list)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue