forked from mindspore-Ecosystem/mindspore
Loss to LossBase
This commit is contained in:
parent
11694db238
commit
30d4d5eedb
|
@ -24,7 +24,7 @@ import mindspore.nn as nn
|
|||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.nn.loss.loss import Loss
|
||||
from mindspore.nn.loss.loss import LossBase
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.context import ParallelMode
|
||||
|
@ -34,7 +34,7 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
|||
from src.seq2seq import Encoder, Decoder
|
||||
|
||||
|
||||
class NLLLoss(Loss):
|
||||
class NLLLoss(LossBase):
|
||||
def __init__(self, reduction='mean'):
|
||||
super(NLLLoss, self).__init__(reduction)
|
||||
self.one_hot = P.OneHot()
|
||||
|
|
Loading…
Reference in New Issue