!6474 add shape check

Merge pull request !6474 from lijiaqi/valiation
This commit is contained in:
mindspore-ci-bot 2020-09-18 16:16:02 +08:00 committed by Gitee
commit 05027e6c41
2 changed files with 7 additions and 4 deletions

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore package."""
""".. MindSpore package."""
from ._version_check import check_version_and_env_config
from . import common, train

View File

@ -180,7 +180,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
network (Cell): The training network. The network only supports single output.
optimizer (Cell): Optimizer for updating the weights.
scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value
is Tensor type, Tensor with shape :math:`()`.
is Tensor type, Tensor with shape :math:`()` or :math:`(1,)`.
Inputs:
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
@ -230,7 +230,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
name="scale_sense")
elif isinstance(scale_sense, Tensor):
self.scale_sense = Parameter(scale_sense, name='scale_sense')
if scale_sense.shape == (1,) or scale_sense.shape == ():
self.scale_sense = Parameter(scale_sense, name='scale_sense')
else:
raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
else:
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
@ -284,4 +287,4 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
if self.scale_sense and isinstance(sens, Tensor):
self.scale_sense.set_data(sens)
else:
raise TypeError("The input type must be Tensor,but got {}".format(type(sens)))
raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))