add functional method for UniformCandidateSampler

This commit is contained in:
louie5 2022-07-19 18:14:30 +08:00
parent c9937a553f
commit ae60a9fd95
8 changed files with 169 additions and 34 deletions

View File

@ -324,6 +324,7 @@ Tensor创建
mindspore.ops.standard_laplace
mindspore.ops.uniform
mindspore.ops.standard_normal
mindspore.ops.uniform_candidate_sampler
Array操作
^^^^^^^^^^^^^^^^

View File

@ -0,0 +1,32 @@
mindspore.ops.uniform_candidate_sampler
======================================
.. py:function:: mindspore.ops.uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False)
使用均匀分布对一组类别进行采样。
此函数使用均匀分布从[0, range_max-1]中采样一组类（sampled_candidates）。如果 `unique` 为True，则候选采样没有重复；如果 `unique` 为False，则有重复。
**参数:**
- **true_classes** (Tensor) - 输入Tensor，目标类，其shape为(batch_size, num_true)。
- **num_true** (int) - 每个训练样本的目标类数。
- **num_sampled** (int) - 随机采样的类数。sampled_candidates的shape将为 `num_sampled` 。如果 `unique` 为True，`num_sampled` 必须小于或等于 `range_max` 。
- **unique** (bool) - 表示一个batch中的所有采样类是否唯一。
- **range_max** (int) - 可能的类数,该值必须是非负的。
- **seed** (int) - 随机种子，该值必须是非负的。如果seed的值为0，则seed的值将被随机生成的值替换。默认值：0。
- **remove_accidental_hits** (bool) - 表示是否移除accidental hit。默认值：False。
**返回:**
- **sampled_candidates** (Tensor) - 候选采样与目标类之间不存在联系，其shape为(num_sampled, )。
- **true_expected_count** (Tensor) - 在每组目标类的采样分布下的预期计数。Shape为(batch_size, num_true)。
- **sampled_expected_count** (Tensor) - 每个候选采样分布下的预期计数。Shape为(num_sampled, )。
**异常:**
- **TypeError** - `num_true` 和 `num_sampled` 都不是int。
- **TypeError** - `unique` 和 `remove_accidental_hits` 都不是bool。
- **TypeError** - `range_max` 和 `seed` 都不是int。
- **TypeError** - `true_classes` 不是Tensor。

View File

@ -323,6 +323,7 @@ Randomly Generating Operators
mindspore.ops.standard_laplace
mindspore.ops.uniform
mindspore.ops.standard_normal
mindspore.ops.uniform_candidate_sampler
Array Operation
^^^^^^^^^^^^^^^

View File

@ -45,8 +45,14 @@ abstract::TupleShapePtr UCSInferShape(const PrimitivePtr &primitive, const std::
batch_rank = GetValue<int64_t>(value_ptr);
}
const int64_t input_dim = 2;
(void)CheckAndConvertUtils::CheckInteger("dimension of input", SizeToLong(input_shape.size()), kGreaterEqual,
input_dim, op_name);
if (batch_rank > 0) {
// support vmap feature
(void)CheckAndConvertUtils::CheckInteger("dimension of input", SizeToLong(input_shape.size()), kGreaterThan,
input_dim, op_name);
} else {
(void)CheckAndConvertUtils::CheckInteger("dimension of input", SizeToLong(input_shape.size()), kEqual, input_dim,
op_name);
}
bool x_not_dyn = std::all_of(input_shape.begin(), input_shape.end(),
[](int64_t value) { return value != abstract::Shape::SHP_ANY; });

View File

@ -321,6 +321,7 @@ from .random_func import (
uniform,
standard_normal,
random_gamma,
uniform_candidate_sampler,
)
__all__ = []

View File

@ -274,11 +274,63 @@ def standard_normal(shape, seed=0, seed2=0):
return standard_normal_op(shape)
def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0,
                              remove_accidental_hits=False):
    r"""
    Uniform candidate sampler.

    Draws a set of classes (sampled_candidates) from [0, range_max-1] using a uniform distribution.
    With unique=True the candidates are drawn without replacement; with unique=False they are drawn
    with replacement.

    Args:
        true_classes (Tensor): The target classes, with shape (batch_size, num_true).
        num_true (int): The number of target classes in each training example.
        num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
            of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
        unique (bool): Whether all sampled classes in a batch are unique.
        range_max (int): The number of possible classes, must be non-negative.
        seed (int): Used for random number generation, must be non-negative. If seed has a value of 0,
            the seed will be replaced with a randomly generated value. Default: 0.
        remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.

    Returns:
        - **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes.
          Shape: (num_sampled, ).
        - **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each
          of true_classes. Shape: (batch_size, num_true).
        - **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of
          each of sampled_candidates. Shape: (num_sampled, ).

    Raises:
        TypeError: If neither `num_true` nor `num_sampled` is an int.
        TypeError: If neither `unique` nor `remove_accidental_hits` is a bool.
        TypeError: If neither `range_max` nor `seed` is an int.
        TypeError: If `true_classes` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> data = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32))
        >>> output1, output2, output3 = ops.uniform_candidate_sampler(data, 1, 3, False, 4, 1)
        >>> print(output1)
        >>> print(output2)
        >>> print(output3)
        [0, 0, 3]
        [[0.75], [0.75], [0.75], [0.75], [0.75]]
        [0.75, 0.75, 0.75]
    """
    # Look up the primitive through the cache helper, then configure it with the sampling attributes.
    sampler_cls = _get_cache_prim(P.UniformCandidateSampler)
    sampler = sampler_cls(num_true, num_sampled, unique, range_max, seed=seed,
                          remove_accidental_hits=remove_accidental_hits)
    # The primitive yields three outputs; unpack and return them in the documented order.
    candidates, true_count, sampled_count = sampler(true_classes)
    return candidates, true_count, sampled_count
# Public names exported by this module.
# Each entry carries its own trailing comma: two adjacent string literals with no
# comma between them are silently concatenated by Python, which would export a
# bogus name ('random_gammarandom_gamma') and hide the real one.
__all__ = [
    'standard_laplace',
    'random_categorical',
    'uniform',
    'standard_normal',
    'random_gamma',
    'uniform_candidate_sampler',
]
# Keep the exported names in alphabetical order.
__all__.sort()

View File

@ -811,45 +811,19 @@ class UniformCandidateSampler(PrimitiveWithInfer):
Uniform candidate sampler.
This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution.
If unique=True, candidates are drawn without replacement, else unique=False with replacement.
Args:
num_true (int): The number of target classes in each training example.
num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
unique (bool): Whether all sampled classes in a batch are unique.
range_max (int): The number of possible classes, must be non-negative.
seed (int): Used for random number generation, must be non-negative. If seed has a value of 0,
the seed will be replaced with a randomly generated value. Default: 0.
remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.
Inputs:
- **true_classes** (Tensor) - A Tensor. The target classes with a Tensor shape of (batch_size, num_true).
Outputs:
- **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes.
Shape: (num_sampled, ).
- **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each
of true_classes. Shape: (batch_size, num_true).
- **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of
each of sampled_candidates. Shape: (num_sampled, ).
Raises:
TypeError: If neither `num_true` nor `num_sampled` is an int.
TypeError: If neither `unique` nor `remove_accidental_hits` is a bool.
TypeError: If neither `range_max` nor `seed` is an int.
TypeError: If `true_classes` is not a Tensor.
Refer to :func:`mindspore.ops.uniform_candidate_sampler` for more detail.
Supported Platforms:
``GPU`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> sampler = ops.UniformCandidateSampler(1, 3, False, 4)
>>> sampler = ops.UniformCandidateSampler(1, 3, False, 4, 1)
>>> output1, output2, output3 = sampler(Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32)))
>>> print(output1)
>>> print(output2)
>>> print(output3)
[1, 1, 3]
[0, 0, 3]
[[0.75], [0.75], [0.75], [0.75], [0.75]]
[0.75, 0.75, 0.75]
"""

View File

@ -18,6 +18,7 @@ import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.functional import vmap
@ -42,6 +43,11 @@ def uniform_candidate_sampler(x, num_true, num_sampled, unique, range_max):
return out1.shape, out2.shape, out3.shape
def uniform_candidate_sampler_functional(x, num_true, num_sample, unique, range_max):
    """Call F.uniform_candidate_sampler on `x` cast to int32 and return the shapes of its three outputs."""
    true_classes = Tensor(x.astype(np.int32))
    sampled, true_count, sampled_count = F.uniform_candidate_sampler(true_classes, num_true, num_sample,
                                                                     unique, range_max)
    return sampled.shape, true_count.shape, sampled_count.shape
def uniform_candidate_sampler_int64(x, num_true, num_sampled, unique, range_max):
uniform_candidate_sampler_net = UniformCandidateSamplerNet(num_true,
num_sampled,
@ -210,7 +216,7 @@ def test_uniform_candidate_sampler_large_random_int64_input():
Description: The input data is random large with type int64 for UniformCandidateSampler
Expectation: The value and shape of output are the expected values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
ms1, ms2, ms3 = uniform_candidate_sampler_int64(np.arange(2142).reshape(34, 63),
63, 10, False, 12)
expected_1 = (10,)
@ -373,3 +379,65 @@ def test_uniform_candidate_sampler_vmap2_unique_1_true():
np.testing.assert_array_equal(ms1, expected_1)
np.testing.assert_array_equal(ms2, expected_2)
np.testing.assert_array_equal(ms3, expected_3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_uniform_candidate_sampler_functional_unique_1_true():
    """
    Feature: Functional interface of UniformCandidateSampler CPU TEST.
    Description: Call uniform_candidate_sampler in graph mode with unique=True and num_true=1.
    Expectation: The shapes of the three outputs equal the expected shapes.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    true_classes = np.array([[1], [3], [4], [6], [3]])
    sampled_shape, true_count_shape, sampled_count_shape = uniform_candidate_sampler_functional(
        true_classes, 1, 3, True, 4)
    # 3 sampled candidates; 5 examples with 1 true class each.
    np.testing.assert_array_equal(sampled_shape, (3,))
    np.testing.assert_array_equal(true_count_shape, (5, 1))
    np.testing.assert_array_equal(sampled_count_shape, (3,))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_uniform_candidate_sampler_functional_not_unique_2_true():
    """
    Feature: Functional interface of UniformCandidateSampler CPU TEST.
    Description: Call uniform_candidate_sampler in graph mode with unique=False and num_true=2.
    Expectation: The shapes of the three outputs equal the expected shapes.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    true_classes = np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]])
    sampled_shape, true_count_shape, sampled_count_shape = uniform_candidate_sampler_functional(
        true_classes, 2, 3, False, 4)
    # 3 sampled candidates; 5 examples with 2 true classes each.
    np.testing.assert_array_equal(sampled_shape, (3,))
    np.testing.assert_array_equal(true_count_shape, (5, 2))
    np.testing.assert_array_equal(sampled_count_shape, (3,))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_uniform_candidate_sampler_functional_large_random():
    """
    Feature: Functional interface of UniformCandidateSampler CPU TEST.
    Description: Call uniform_candidate_sampler in PyNative mode with a (34, 63) input cast to int32.
    Expectation: The shapes of the three outputs equal the expected shapes.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    true_classes = np.arange(2142).reshape(34, 63)
    sampled_shape, true_count_shape, sampled_count_shape = uniform_candidate_sampler_functional(
        true_classes, 63, 10, False, 12)
    # 10 sampled candidates; 34 examples with 63 true classes each.
    np.testing.assert_array_equal(sampled_shape, (10,))
    np.testing.assert_array_equal(true_count_shape, (34, 63))
    np.testing.assert_array_equal(sampled_count_shape, (10,))