!21679 add fat-deepffm master

Merge pull request !21679 from four_WW/dffm_master
This commit is contained in:
i-robot 2021-08-12 08:15:40 +00:00 committed by Gitee
commit e25118eb66
2 changed files with 4 additions and 4 deletions

View File

@ -21,6 +21,7 @@ from mindspore.common.initializer import initializer
import mindspore.ops as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import Parameter, ParameterTuple
from mindspore import Tensor
@ -350,8 +351,7 @@ class TrainStepWrap(nn.Cell):
 grads = self.grad(self.network, weights)(cats_vals, num_vals, label, sens)
 if self.reducer_flag:
 grads = self.grad_reducer(grads)
-self.optimizer(grads)
-return loss
+return F.depend(loss, self.optimizer(grads))
class ModelBuilder:

View File

@ -176,8 +176,8 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, recommendat
 dense_list = []
 label_list = []
-writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 1)
-writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 1)
+writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 21)
+writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 3)
 schema = {"label": {"type": "float32", "shape": [-1]}, "num_vals": {"type": "float32", "shape": [-1]},
 "cats_vals": {"type": "int32", "shape": [-1]}}