diff --git a/docs/api/api_python/mindspore.ops.functional.rst b/docs/api/api_python/mindspore.ops.functional.rst index fd2c46aef93..f0fdb9ebc20 100644 --- a/docs/api/api_python/mindspore.ops.functional.rst +++ b/docs/api/api_python/mindspore.ops.functional.rst @@ -324,6 +324,7 @@ Tensor创建 mindspore.ops.standard_laplace mindspore.ops.uniform mindspore.ops.standard_normal + mindspore.ops.uniform_candidate_sampler Array操作 ^^^^^^^^^^^^^^^^ diff --git a/docs/api/api_python/ops/mindspore.ops.func_uniform_candidate_sampler.rst b/docs/api/api_python/ops/mindspore.ops.func_uniform_candidate_sampler.rst new file mode 100644 index 00000000000..1f7d1b49ff0 --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_uniform_candidate_sampler.rst @@ -0,0 +1,32 @@ +mindspore.ops.uniform_candidate_sampler +====================================== + +.. py:function:: mindspore.ops.uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False) + + 使用均匀分布对一组类别进行采样。 + + 此函数使用均匀分布从[0, range_max-1]中采样一组类(sampled_candidates)。如果 `unique` 为True,则候选采样没有重复;如果 `unique` 为False,则有重复。 + + **参数:** + + - **true_classes** (Tensor) - 输入Tensor,目标类,其shape为(batch_size, num_true)。 + - **num_true** (int) - 每个训练样本的目标类数。 + - **num_sampled** (int) - 随机采样的类数。sampled_candidates的shape将为 `num_sampled` 。如果 `unique` 为True,则 `num_sampled` 必须小于或等于 `range_max` 。 + - **unique** (bool) - 表示一个batch中的所有采样类是否唯一。 + - **range_max** (int) - 可能的类数,该值必须是非负的。 + - **seed** (int) - 随机种子,该值必须是非负的。如果seed的值为0,则seed的值将被随机生成的值替换。默认值:0。 + - **remove_accidental_hits** (bool) - 表示是否移除accidental hit。默认值:False。 + + **返回:** + + - **sampled_candidates** (Tensor) - 候选采样与目标类之间不存在联系,其shape为(num_sampled, )。 + - **true_expected_count** (Tensor) - 在每组目标类的采样分布下的预期计数。Shape为(batch_size, num_true)。 + - **sampled_expected_count** (Tensor) - 每个候选采样分布下的预期计数。Shape为(num_sampled, )。 + + **异常:** + + - **TypeError** - `num_true` 和 `num_sampled` 都不是int。 + - **TypeError** - `uique` 和 `remo_acidental_hits` 都不是bool。 + - **TypeError** - `range_max` 和 `seed` 都不是int。 + - **TypeError** - `true_classes` 不是Tensor。 + \ No newline at end of file diff --git a/docs/api/api_python_en/mindspore.ops.functional.rst b/docs/api/api_python_en/mindspore.ops.functional.rst index f8c4b5d8d69..d052c49b7a8 100644 --- a/docs/api/api_python_en/mindspore.ops.functional.rst +++ b/docs/api/api_python_en/mindspore.ops.functional.rst @@ -323,6 +323,7 @@ Randomly Generating Operators mindspore.ops.standard_laplace mindspore.ops.uniform mindspore.ops.standard_normal + mindspore.ops.uniform_candidate_sampler Array Operation ^^^^^^^^^^^^^^^ diff --git a/mindspore/core/ops/uniform_candidate_sampler.cc b/mindspore/core/ops/uniform_candidate_sampler.cc index feb6fd1b872..98a34678be9 100644 --- a/mindspore/core/ops/uniform_candidate_sampler.cc +++ b/mindspore/core/ops/uniform_candidate_sampler.cc @@ -45,8 +45,14 @@ abstract::TupleShapePtr UCSInferShape(const PrimitivePtr &primitive, const std:: batch_rank = GetValue(value_ptr); } const int64_t input_dim = 2; - (void)CheckAndConvertUtils::CheckInteger("dimension of input", SizeToLong(input_shape.size()), kGreaterEqual, - input_dim, op_name); + if (batch_rank > 0) { + // support vmap feature + (void)CheckAndConvertUtils::CheckInteger("dimension of input", SizeToLong(input_shape.size()), kGreaterThan, + input_dim, op_name); + } else { + (void)CheckAndConvertUtils::CheckInteger("dimension of input", SizeToLong(input_shape.size()), kEqual, input_dim, + op_name); + } bool x_not_dyn = std::all_of(input_shape.begin(), input_shape.end(), [](int64_t value) { return value != abstract::Shape::SHP_ANY; }); diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py index f4a9da65949..faf3b585a36 100644 --- a/mindspore/python/mindspore/ops/function/__init__.py +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -321,6 +321,7 @@ from .random_func import ( uniform, standard_normal, random_gamma, + uniform_candidate_sampler, ) __all__ = [] diff --git a/mindspore/python/mindspore/ops/function/random_func.py b/mindspore/python/mindspore/ops/function/random_func.py index 5bb8b3a49bf..bd9557cd155 100644 --- a/mindspore/python/mindspore/ops/function/random_func.py +++ b/mindspore/python/mindspore/ops/function/random_func.py @@ -274,11 +274,63 @@ def standard_normal(shape, seed=0, seed2=0): return standard_normal_op(shape) +def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0, + remove_accidental_hits=False): + r""" + Uniform candidate sampler. + + This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution. + If unique=True, candidates are drawn without replacement, else unique=False with replacement. + + Args: + true_classes (Tensor) - A Tensor. The target classes with a Tensor shape of (batch_size, num_true). + num_true (int): The number of target classes in each training example. + num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape + of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. + unique (bool): Whether all sampled classes in a batch are unique. + range_max (int): The number of possible classes, must be non-negative. + seed (int): Used for random number generation, must be non-negative. If seed has a value of 0, + the seed will be replaced with a randomly generated value. Default: 0. + remove_accidental_hits (bool): Whether accidental hit is removed. Default: False. + + Returns: + - **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes. + Shape: (num_sampled, ). + - **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each + of true_classes. Shape: (batch_size, num_true). + - **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of + each of sampled_candidates. Shape: (num_sampled, ). + + Raises: + TypeError: If neither `num_true` nor `num_sampled` is an int. + TypeError: If neither `unique` nor `remove_accidental_hits` is a bool. + TypeError: If neither `range_max` nor `seed` is an int. + TypeError: If `true_classes` is not a Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> data = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32)) + >>> output1, output2, output3 = ops.uniform_candidate_sampler(data, 1, 3, False, 4, 1) + >>> print(output1) + >>> print(output2) + >>> print(output3) + [0, 0, 3] + [[0.75], [0.75], [0.75], [0.75], [0.75]] + [0.75, 0.75, 0.75] + """ + sampler_op = _get_cache_prim(P.UniformCandidateSampler)(num_true, num_sampled, unique, range_max, seed=seed, + remove_accidental_hits=remove_accidental_hits) + sampled_candidates, true_expected_count, sampled_expected_count = sampler_op(true_classes) + return sampled_candidates, true_expected_count, sampled_expected_count + __all__ = [ 'standard_laplace', 'random_categorical', 'uniform', 'standard_normal', - 'random_gamma' + 'random_gamma', + 'uniform_candidate_sampler' ] __all__.sort() diff --git a/mindspore/python/mindspore/ops/operations/random_ops.py b/mindspore/python/mindspore/ops/operations/random_ops.py index 27cffad85d2..21eed5d66cb 100644 --- a/mindspore/python/mindspore/ops/operations/random_ops.py +++ b/mindspore/python/mindspore/ops/operations/random_ops.py @@ -811,45 +811,19 @@ class UniformCandidateSampler(PrimitiveWithInfer): Uniform candidate sampler. This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution. - If unique=True, candidates are drawn without replacement, else unique=False with replacement. - Args: - num_true (int): The number of target classes in each training example. - num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape - of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. - unique (bool): Whether all sampled classes in a batch are unique. - range_max (int): The number of possible classes, must be non-negative. - seed (int): Used for random number generation, must be non-negative. If seed has a value of 0, - the seed will be replaced with a randomly generated value. Default: 0. - remove_accidental_hits (bool): Whether accidental hit is removed. Default: False. - - Inputs: - - **true_classes** (Tensor) - A Tensor. The target classes with a Tensor shape of (batch_size, num_true). - - Outputs: - - **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes. - Shape: (num_sampled, ). - - **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each - of true_classes. Shape: (batch_size, num_true). - - **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of - each of sampled_candidates. Shape: (num_sampled, ). - - Raises: - TypeError: If neither `num_true` nor `num_sampled` is an int. - TypeError: If neither `unique` nor `remove_accidental_hits` is a bool. - TypeError: If neither `range_max` nor `seed` is an int. - TypeError: If `true_classes` is not a Tensor. + Refer to :func:`mindspore.ops.uniform_candidate_sampler` for more detail. Supported Platforms: - ``GPU`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: - >>> sampler = ops.UniformCandidateSampler(1, 3, False, 4) + >>> sampler = ops.UniformCandidateSampler(1, 3, False, 4, 1) >>> output1, output2, output3 = sampler(Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32))) >>> print(output1) >>> print(output2) >>> print(output3) - [1, 1, 3] + [0, 0, 3] [[0.75], [0.75], [0.75], [0.75], [0.75]] [0.75, 0.75, 0.75] """ diff --git a/tests/st/ops/cpu/test_uniform_candidate_sampler_op.py b/tests/st/ops/cpu/test_uniform_candidate_sampler_op.py index 37890e0606d..f91ab20661c 100644 --- a/tests/st/ops/cpu/test_uniform_candidate_sampler_op.py +++ b/tests/st/ops/cpu/test_uniform_candidate_sampler_op.py @@ -18,6 +18,7 @@ import pytest from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import functional as F import mindspore.nn as nn import mindspore.context as context from mindspore.ops.functional import vmap @@ -42,6 +43,11 @@ def uniform_candidate_sampler(x, num_true, num_sampled, unique, range_max): return out1.shape, out2.shape, out3.shape +def uniform_candidate_sampler_functional(x, num_true, num_sample, unique, range_max): + out1, out2, out3 = F.uniform_candidate_sampler(Tensor(x.astype(np.int32)), num_true, num_sample, unique, range_max) + return out1.shape, out2.shape, out3.shape + + def uniform_candidate_sampler_int64(x, num_true, num_sampled, unique, range_max): uniform_candidate_sampler_net = UniformCandidateSamplerNet(num_true, num_sampled, @@ -210,7 +216,7 @@ def test_uniform_candidate_sampler_large_random_int64_input(): Description: The input data is random large with type int64 for UniformCandidateSampler Expectation: The value and shape of output are the expected values. """ - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") ms1, ms2, ms3 = uniform_candidate_sampler_int64(np.arange(2142).reshape(34, 63), 63, 10, False, 12) expected_1 = (10,) @@ -373,3 +379,65 @@ def test_uniform_candidate_sampler_vmap2_unique_1_true(): np.testing.assert_array_equal(ms1, expected_1) np.testing.assert_array_equal(ms2, expected_2) np.testing.assert_array_equal(ms3, expected_3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_uniform_candidate_sampler_functional_unique_1_true(): + """ + Feature: Functional interface of UniformCandidateSampler CPU TEST. + Description: The unique is true for uniform_candidate_sampler + Expectation: The shape of output are the expected values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + ms1, ms2, ms3 = uniform_candidate_sampler_functional(np.array([[1], [3], [4], [6], [3]]), + 1, 3, True, 4) + expected_1 = (3,) + expected_2 = (5, 1) + expected_3 = (3,) + np.testing.assert_array_equal(ms1, expected_1) + np.testing.assert_array_equal(ms2, expected_2) + np.testing.assert_array_equal(ms3, expected_3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_uniform_candidate_sampler_functional_not_unique_2_true(): + """ + Feature: Functional interface of UniformCandidateSampler CPU TEST. + Description: The unique is false and num_true is 2 for uniform_candidate_sampler + Expectation: The value and shape of output are the expected values. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + ms1, ms2, ms3 = uniform_candidate_sampler_functional(np.array([[1, 2], [3, 2], + [4, 2], [6, 2], + [3, 2]]), + 2, 3, False, 4) + expected_1 = (3,) + expected_2 = (5, 2) + expected_3 = (3,) + np.testing.assert_array_equal(ms1, expected_1) + np.testing.assert_array_equal(ms2, expected_2) + np.testing.assert_array_equal(ms3, expected_3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_uniform_candidate_sampler_functional_large_random(): + """ + Feature: Functional interface of UniformCandidateSampler CPU TEST. + Description: The input data is random large with type int32 for uniform_candidate_sampler + Expectation: The shape of output are the expected values. + """ + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + ms1, ms2, ms3 = uniform_candidate_sampler_functional(np.arange(2142).reshape(34, 63), + 63, 10, False, 12) + expected_1 = (10,) + expected_2 = (34, 63) + expected_3 = (10,) + np.testing.assert_array_equal(ms1, expected_1) + np.testing.assert_array_equal(ms2, expected_2) + np.testing.assert_array_equal(ms3, expected_3)