!8803 Validate SampledSoftmaxLoss Args

From: @jonwe
Reviewed-by: @tom__chen,@robingrosman
Signed-off-by: @robingrosman
This commit is contained in:
mindspore-ci-bot 2020-11-20 10:58:39 +08:00 committed by Gitee
commit 044c7d183c
2 changed files with 47 additions and 2 deletions

View File

@ -284,7 +284,7 @@ class SampledSoftmaxLoss(_Loss):
where a sampled class equals one of the target classes. Default is True.
seed (int): Random seed for candidate sampling. Default: 0
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
If "none", do not perform reduction. Default: "None".
If "none", do not perform reduction. Default: "none".
Inputs:
- **weights** (Tensor) - Tensor of shape (C, dim).
@ -311,7 +311,22 @@ class SampledSoftmaxLoss(_Loss):
def __init__(self, num_sampled, num_classes, num_true=1,
sampled_values=None, remove_accidental_hits=True, seed=0,
reduction='none'):
super(SampledSoftmaxLoss, self).__init__()
super(SampledSoftmaxLoss, self).__init__(reduction)
if num_true < 1:
raise ValueError(f"num_true {num_true} is less than 1.")
if seed < 0:
raise ValueError(f"seed {seed} is less than 0.")
if num_sampled > num_classes:
raise ValueError(f"num_sampled {num_sampled} is great than num_classes {num_classes}.")
if num_true > num_classes:
raise ValueError(f"num_true {num_true} is great than num_classes {num_classes}.")
if sampled_values is not None:
if not isinstance(sampled_values, (list, tuple)):
raise TypeError(f"sampled_values {sampled_values} is not a list.")
if len(sampled_values) != 3:
raise ValueError(f"sampled_values size {len(sampled_values)} is not 3.")
self.num_sampled = num_sampled
self.num_classes = num_classes
self.num_true = num_true

View File

@ -131,6 +131,36 @@ def test_sampled_softmax_loss_none_sampler():
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
case_no_sampler()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sampledsoftmaxloss_reduction_invalid():
# Check 'reduction'
with pytest.raises(ValueError):
nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, reduction="")
with pytest.raises(ValueError):
nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, reduction="invalid")
# reduction can be None, as defined in _Loss
# with pytest.raises(ValueError):
# nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, reduction=None) #
# Check 'num_true'
with pytest.raises(ValueError):
nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, num_true=0)
# Check 'sampled_values'
with pytest.raises(ValueError):
sampled_values_more_para = (Tensor(np.array([1])), Tensor(np.array([1])),
Tensor(np.array([1])), Tensor(np.array([1])))
nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7,
sampled_values=sampled_values_more_para)
with pytest.raises(TypeError):
sampled_values_wrong_type = Tensor(np.array([1]))
nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7,
sampled_values=sampled_values_wrong_type)
if __name__ == "__main__":
test_sampled_softmax_loss_assigned_sampler()