forked from mindspore-Ecosystem/mindspore
modify example
This commit is contained in:
parent
18fb04b7b0
commit
006879fbb8
|
@ -162,9 +162,9 @@ class TrainOneStepCell(Cell):
|
|||
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
||||
>>>
|
||||
>>> #2) Using user-defined WithLossCell
|
||||
>>>class MyWithLossCell(nn.cell):
|
||||
>>>class MyWithLossCell(nn.Cell):
|
||||
>>> def __init__(self, backbone, loss_fn):
|
||||
>>> super(WithLossCell, self).__init__(auto_prefix=False)
|
||||
>>> super(MyWithLossCell, self).__init__(auto_prefix=False)
|
||||
>>> self._backbone = backbone
|
||||
>>> self._loss_fn = loss_fn
|
||||
>>>
|
||||
|
@ -172,6 +172,10 @@ class TrainOneStepCell(Cell):
|
|||
>>> out = self._backbone(x, y)
|
||||
>>> return self._loss_fn(out, label)
|
||||
>>>
|
||||
>>> @property
|
||||
>>> def backbone_network(self):
|
||||
>>> return self._backbone
|
||||
>>>
|
||||
>>> loss_net = MyWithLossCell(net, loss_fn)
|
||||
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue