forked from mindspore-Ecosystem/mindspore
!8803 Validate SampledSoftmaxLoss Args
From: @jonwe Reviewed-by: @tom__chen,@robingrosman Signed-off-by: @robingrosman
This commit is contained in:
commit
044c7d183c
|
@ -284,7 +284,7 @@ class SampledSoftmaxLoss(_Loss):
|
|||
where a sampled class equals one of the target classes. Default is True.
|
||||
seed (int): Random seed for candidate sampling. Default: 0
|
||||
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
|
||||
If "none", do not perform reduction. Default: "None".
|
||||
If "none", do not perform reduction. Default: "none".
|
||||
|
||||
Inputs:
|
||||
- **weights** (Tensor) - Tensor of shape (C, dim).
|
||||
|
@ -311,7 +311,22 @@ class SampledSoftmaxLoss(_Loss):
|
|||
def __init__(self, num_sampled, num_classes, num_true=1,
|
||||
sampled_values=None, remove_accidental_hits=True, seed=0,
|
||||
reduction='none'):
|
||||
super(SampledSoftmaxLoss, self).__init__()
|
||||
super(SampledSoftmaxLoss, self).__init__(reduction)
|
||||
|
||||
if num_true < 1:
|
||||
raise ValueError(f"num_true {num_true} is less than 1.")
|
||||
if seed < 0:
|
||||
raise ValueError(f"seed {seed} is less than 0.")
|
||||
if num_sampled > num_classes:
|
||||
raise ValueError(f"num_sampled {num_sampled} is great than num_classes {num_classes}.")
|
||||
if num_true > num_classes:
|
||||
raise ValueError(f"num_true {num_true} is great than num_classes {num_classes}.")
|
||||
if sampled_values is not None:
|
||||
if not isinstance(sampled_values, (list, tuple)):
|
||||
raise TypeError(f"sampled_values {sampled_values} is not a list.")
|
||||
if len(sampled_values) != 3:
|
||||
raise ValueError(f"sampled_values size {len(sampled_values)} is not 3.")
|
||||
|
||||
self.num_sampled = num_sampled
|
||||
self.num_classes = num_classes
|
||||
self.num_true = num_true
|
||||
|
|
|
@ -131,6 +131,36 @@ def test_sampled_softmax_loss_none_sampler():
|
|||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
case_no_sampler()
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_sampledsoftmaxloss_reduction_invalid():
|
||||
# Check 'reduction'
|
||||
with pytest.raises(ValueError):
|
||||
nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, reduction="")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, reduction="invalid")
|
||||
|
||||
# reduction can be None, as defined in _Loss
|
||||
# with pytest.raises(ValueError):
|
||||
# nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, reduction=None) #
|
||||
|
||||
# Check 'num_true'
|
||||
with pytest.raises(ValueError):
|
||||
nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, num_true=0)
|
||||
|
||||
# Check 'sampled_values'
|
||||
with pytest.raises(ValueError):
|
||||
sampled_values_more_para = (Tensor(np.array([1])), Tensor(np.array([1])),
|
||||
Tensor(np.array([1])), Tensor(np.array([1])))
|
||||
nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7,
|
||||
sampled_values=sampled_values_more_para)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
sampled_values_wrong_type = Tensor(np.array([1]))
|
||||
nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7,
|
||||
sampled_values=sampled_values_wrong_type)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_sampled_softmax_loss_assigned_sampler()
|
||||
|
|
Loading…
Reference in New Issue