forked from OSSInnovation/mindspore
commit
05027e6c41
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""MindSpore package."""
|
||||
""".. MindSpore package."""
|
||||
|
||||
from ._version_check import check_version_and_env_config
|
||||
from . import common, train
|
||||
|
|
|
@ -180,7 +180,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
network (Cell): The training network. The network only supports single output.
|
||||
optimizer (Cell): Optimizer for updating the weights.
|
||||
scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value
|
||||
is Tensor type, Tensor with shape :math:`()`.
|
||||
is Tensor type, Tensor with shape :math:`()` or :math:`(1,)`.
|
||||
|
||||
Inputs:
|
||||
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
|
||||
|
@ -230,7 +230,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
|
||||
name="scale_sense")
|
||||
elif isinstance(scale_sense, Tensor):
|
||||
self.scale_sense = Parameter(scale_sense, name='scale_sense')
|
||||
if scale_sense.shape == (1,) or scale_sense.shape == ():
|
||||
self.scale_sense = Parameter(scale_sense, name='scale_sense')
|
||||
else:
|
||||
raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
|
||||
else:
|
||||
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
|
||||
|
||||
|
@ -284,4 +287,4 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
|
|||
if self.scale_sense and isinstance(sens, Tensor):
|
||||
self.scale_sense.set_data(sens)
|
||||
else:
|
||||
raise TypeError("The input type must be Tensor,but got {}".format(type(sens)))
|
||||
raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))
|
||||
|
|
Loading…
Reference in New Issue