fat_deepffm commit

This commit is contained in:
3236961631@qq.com 2021-08-11 20:54:09 +08:00
parent 50d54a7482
commit 697550f4be
2 changed files with 4 additions and 4 deletions

View File

@ -21,6 +21,7 @@ from mindspore.common.initializer import initializer
import mindspore.ops as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import Parameter, ParameterTuple
from mindspore import Tensor
@ -350,8 +351,7 @@ class TrainStepWrap(nn.Cell):
grads = self.grad(self.network, weights)(cats_vals, num_vals, label, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
self.optimizer(grads)
return loss
return F.depend(loss, self.optimizer(grads))
class ModelBuilder:

View File

@ -176,8 +176,8 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, recommendat
dense_list = []
label_list = []
writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 1)
writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 1)
writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 21)
writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 3)
schema = {"label": {"type": "float32", "shape": [-1]}, "num_vals": {"type": "float32", "shape": [-1]},
"cats_vals": {"type": "int32", "shape": [-1]}}