From df48941c3b495b404ef3a9707087fee42c2cc37d Mon Sep 17 00:00:00 2001 From: lichenever Date: Wed, 15 Jul 2020 20:37:20 +0800 Subject: [PATCH] fix model_zoo --- model_zoo/wide_and_deep/src/wide_and_deep.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model_zoo/wide_and_deep/src/wide_and_deep.py b/model_zoo/wide_and_deep/src/wide_and_deep.py index 5c04687fdce..048bf3c66d5 100644 --- a/model_zoo/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/wide_and_deep/src/wide_and_deep.py @@ -188,7 +188,7 @@ class WideDeepModel(nn.Cell): self.deep_layer_act, use_activation=False, convert_dtype=True, drop_out=config.dropout_flag) - self.embeddinglookup = nn.EmbeddingLookup() + self.embeddinglookup = nn.EmbeddingLookup(target='DEVICE') self.mul = P.Mul() self.reduce_sum = P.ReduceSum(keep_dims=False) self.reshape = P.Reshape() @@ -206,11 +206,11 @@ class WideDeepModel(nn.Cell): """ mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) # Wide layer - wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr, 0) + wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr) wx = self.mul(wide_id_weight, mask) wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) # Deep layer - deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr, 0) + deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr) vx = self.mul(deep_id_embs, mask) deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim)) deep_in = self.dense_layer_1(deep_in)