Loss to LossBase

This commit is contained in:
panfengfeng 2021-07-29 09:04:02 +08:00
parent 11694db238
commit 30d4d5eedb
1 changed files with 2 additions and 2 deletions

View File

@ -24,7 +24,7 @@ import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.nn.loss.loss import Loss
from mindspore.nn.loss.loss import LossBase
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.context import ParallelMode
@ -34,7 +34,7 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
from src.seq2seq import Encoder, Decoder
class NLLLoss(Loss):
class NLLLoss(LossBase):
def __init__(self, reduction='mean'):
super(NLLLoss, self).__init__(reduction)
self.one_hot = P.OneHot()