fix model zoo script error

This commit is contained in:
chujinjin 2021-08-10 09:27:28 +08:00
parent 846393b176
commit b25110eb2c
2 changed files with 2 additions and 2 deletions

View File

@ -109,7 +109,7 @@ class AttnDecoderRNN(nn.Cell):
output = self.relu(output)
gru_hidden = self.squeeze1(hidden)
output, hidden, _, _, _, _ = self.gru(output, gru_hidden)
output, hidden = self.gru(output, gru_hidden)
output = self.squeeze1(output)
output = self.log_softmax(self.out(output))

View File

@ -346,7 +346,7 @@ class PredictWithSigmoid(nn.Cell):
self.sigmoid = P.Sigmoid()
def construct(self, batch_ids, batch_wts, labels):
logits, _, _, = self.network(batch_ids, batch_wts)
logits, _, _, _, _, = self.network(batch_ids, batch_wts)
pred_probs = self.sigmoid(logits)
return logits, pred_probs, labels