From 697550f4be03c9eba1cbc9c922ee36e6096effd0 Mon Sep 17 00:00:00 2001 From: "3236961631@qq.com" <3236961631@qq.com> Date: Wed, 11 Aug 2021 20:54:09 +0800 Subject: [PATCH] fat_deepffm commit --- model_zoo/research/recommend/Fat-DeepFFM/src/fat_deepffm.py | 4 ++-- .../research/recommend/Fat-DeepFFM/src/preprocess_data.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/model_zoo/research/recommend/Fat-DeepFFM/src/fat_deepffm.py b/model_zoo/research/recommend/Fat-DeepFFM/src/fat_deepffm.py index 715c02ff1bf..3de30f1a3b3 100644 --- a/model_zoo/research/recommend/Fat-DeepFFM/src/fat_deepffm.py +++ b/model_zoo/research/recommend/Fat-DeepFFM/src/fat_deepffm.py @@ -21,6 +21,7 @@ from mindspore.common.initializer import initializer import mindspore.ops as P from mindspore.ops import composite as C +from mindspore.ops import functional as F from mindspore import Parameter, ParameterTuple from mindspore import Tensor @@ -350,8 +351,7 @@ class TrainStepWrap(nn.Cell): grads = self.grad(self.network, weights)(cats_vals, num_vals, label, sens) if self.reducer_flag: grads = self.grad_reducer(grads) - self.optimizer(grads) - return loss + return F.depend(loss, self.optimizer(grads)) class ModelBuilder: diff --git a/model_zoo/research/recommend/Fat-DeepFFM/src/preprocess_data.py b/model_zoo/research/recommend/Fat-DeepFFM/src/preprocess_data.py index 8bed00d339b..70ef11d76be 100644 --- a/model_zoo/research/recommend/Fat-DeepFFM/src/preprocess_data.py +++ b/model_zoo/research/recommend/Fat-DeepFFM/src/preprocess_data.py @@ -176,8 +176,8 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, recommendat dense_list = [] label_list = [] - writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 1) - writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 1) + writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 21) + writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 3) schema = {"label": {"type": "float32", "shape": [-1]}, "num_vals": {"type": "float32", "shape": [-1]}, "cats_vals": {"type": "int32", "shape": [-1]}}