forked from mindspore-Ecosystem/mindspore
!21679 add fat-deepffm master
Merge pull request !21679 from four_WW/dffm_master
This commit is contained in:
commit
e25118eb66
|
@ -21,6 +21,7 @@ from mindspore.common.initializer import initializer
|
|||
|
||||
import mindspore.ops as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
from mindspore import Parameter, ParameterTuple
|
||||
from mindspore import Tensor
|
||||
|
@ -350,8 +351,7 @@ class TrainStepWrap(nn.Cell):
|
|||
grads = self.grad(self.network, weights)(cats_vals, num_vals, label, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
self.optimizer(grads)
|
||||
return loss
|
||||
return F.depend(loss, self.optimizer(grads))
|
||||
|
||||
|
||||
class ModelBuilder:
|
||||
|
|
|
@ -176,8 +176,8 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, recommendat
|
|||
dense_list = []
|
||||
label_list = []
|
||||
|
||||
writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 1)
|
||||
writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 1)
|
||||
writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 21)
|
||||
writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 3)
|
||||
|
||||
schema = {"label": {"type": "float32", "shape": [-1]}, "num_vals": {"type": "float32", "shape": [-1]},
|
||||
"cats_vals": {"type": "int32", "shape": [-1]}}
|
||||
|
|
Loading…
Reference in New Issue