add functional method for UniformCandidateSampler
This commit is contained in:
parent
c9937a553f
commit
ae60a9fd95
|
@ -324,6 +324,7 @@ Tensor创建
|
|||
mindspore.ops.standard_laplace
|
||||
mindspore.ops.uniform
|
||||
mindspore.ops.standard_normal
|
||||
mindspore.ops.uniform_candidate_sampler
|
||||
|
||||
Array操作
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
mindspore.ops.uniform_candidate_sampler
|
||||
======================================
|
||||
|
||||
.. py:function:: mindspore.ops.uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False)
|
||||
|
||||
使用均匀分布对一组类别进行采样。
|
||||
|
||||
此函数使用均匀分布从[0, range_max-1]中采样一组类(sampled_candidates)。如果 `unique` 为True,则候选采样没有重复;如果 `unique` 为False,则有重复。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **true_classes** (Tensor) - 输入Tensor,目标类,其shape为(batch_size, num_true)。
|
||||
- **num_true** (int) - 每个训练样本的目标类数。
|
||||
- **num_sampled** (int) - 随机采样的类数。sampled_candidates的shape将为 `num_sampled` 。如果 `unique` 为True,则 `num_sampled` 必须小于或等于 `range_max` 。
|
||||
- **unique** (bool) - 表示一个batch中的所有采样类是否唯一。
|
||||
- **range_max** (int) - 可能的类数,该值必须是非负的。
|
||||
- **seed** (int) - 随机种子,该值必须是非负的。如果seed的值为0,则seed的值将被随机生成的值替换。默认值:0。
|
||||
- **remove_accidental_hits** (bool) - 表示是否移除accidental hit。默认值:False。
|
||||
|
||||
**返回:**
|
||||
|
||||
- **sampled_candidates** (Tensor) - 候选采样与目标类之间不存在联系,其shape为(num_sampled, )。
|
||||
- **true_expected_count** (Tensor) - 在每组目标类的采样分布下的预期计数。Shape为(batch_size, num_true)。
|
||||
- **sampled_expected_count** (Tensor) - 每个候选采样分布下的预期计数。Shape为(num_sampled, )。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `num_true` 和 `num_sampled` 都不是int。
|
||||
- **TypeError** - `uique` 和 `remo_acidental_hits` 都不是bool。
|
||||
- **TypeError** - `range_max` 和 `seed` 都不是int。
|
||||
- **TypeError** - `true_classes` 不是Tensor。
|
||||
|
|
@ -323,6 +323,7 @@ Randomly Generating Operators
|
|||
mindspore.ops.standard_laplace
|
||||
mindspore.ops.uniform
|
||||
mindspore.ops.standard_normal
|
||||
mindspore.ops.uniform_candidate_sampler
|
||||
|
||||
Array Operation
|
||||
^^^^^^^^^^^^^^^
|
||||
|
|
|
@ -45,8 +45,14 @@ abstract::TupleShapePtr UCSInferShape(const PrimitivePtr &primitive, const std::
|
|||
batch_rank = GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
const int64_t input_dim = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("dimension of input", SizeToLong(input_shape.size()), kGreaterEqual,
|
||||
input_dim, op_name);
|
||||
if (batch_rank > 0) {
|
||||
// support vmap feature
|
||||
(void)CheckAndConvertUtils::CheckInteger("dimension of input", SizeToLong(input_shape.size()), kGreaterThan,
|
||||
input_dim, op_name);
|
||||
} else {
|
||||
(void)CheckAndConvertUtils::CheckInteger("dimension of input", SizeToLong(input_shape.size()), kEqual, input_dim,
|
||||
op_name);
|
||||
}
|
||||
|
||||
bool x_not_dyn = std::all_of(input_shape.begin(), input_shape.end(),
|
||||
[](int64_t value) { return value != abstract::Shape::SHP_ANY; });
|
||||
|
|
|
@ -321,6 +321,7 @@ from .random_func import (
|
|||
uniform,
|
||||
standard_normal,
|
||||
random_gamma,
|
||||
uniform_candidate_sampler,
|
||||
)
|
||||
|
||||
__all__ = []
|
||||
|
|
|
@ -274,11 +274,63 @@ def standard_normal(shape, seed=0, seed2=0):
|
|||
return standard_normal_op(shape)
|
||||
|
||||
|
||||
def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0,
|
||||
remove_accidental_hits=False):
|
||||
r"""
|
||||
Uniform candidate sampler.
|
||||
|
||||
This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution.
|
||||
If unique=True, candidates are drawn without replacement, else unique=False with replacement.
|
||||
|
||||
Args:
|
||||
true_classes (Tensor) - A Tensor. The target classes with a Tensor shape of (batch_size, num_true).
|
||||
num_true (int): The number of target classes in each training example.
|
||||
num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
|
||||
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
|
||||
unique (bool): Whether all sampled classes in a batch are unique.
|
||||
range_max (int): The number of possible classes, must be non-negative.
|
||||
seed (int): Used for random number generation, must be non-negative. If seed has a value of 0,
|
||||
the seed will be replaced with a randomly generated value. Default: 0.
|
||||
remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.
|
||||
|
||||
Returns:
|
||||
- **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes.
|
||||
Shape: (num_sampled, ).
|
||||
- **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each
|
||||
of true_classes. Shape: (batch_size, num_true).
|
||||
- **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of
|
||||
each of sampled_candidates. Shape: (num_sampled, ).
|
||||
|
||||
Raises:
|
||||
TypeError: If neither `num_true` nor `num_sampled` is an int.
|
||||
TypeError: If neither `unique` nor `remove_accidental_hits` is a bool.
|
||||
TypeError: If neither `range_max` nor `seed` is an int.
|
||||
TypeError: If `true_classes` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> data = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32))
|
||||
>>> output1, output2, output3 = ops.uniform_candidate_sampler(data, 1, 3, False, 4, 1)
|
||||
>>> print(output1)
|
||||
>>> print(output2)
|
||||
>>> print(output3)
|
||||
[0, 0, 3]
|
||||
[[0.75], [0.75], [0.75], [0.75], [0.75]]
|
||||
[0.75, 0.75, 0.75]
|
||||
"""
|
||||
sampler_op = _get_cache_prim(P.UniformCandidateSampler)(num_true, num_sampled, unique, range_max, seed=seed,
|
||||
remove_accidental_hits=remove_accidental_hits)
|
||||
sampled_candidates, true_expected_count, sampled_expected_count = sampler_op(true_classes)
|
||||
return sampled_candidates, true_expected_count, sampled_expected_count
|
||||
|
||||
__all__ = [
|
||||
'standard_laplace',
|
||||
'random_categorical',
|
||||
'uniform',
|
||||
'standard_normal',
|
||||
'random_gamma'
|
||||
'random_gamma',
|
||||
'uniform_candidate_sampler'
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -811,45 +811,19 @@ class UniformCandidateSampler(PrimitiveWithInfer):
|
|||
Uniform candidate sampler.
|
||||
|
||||
This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution.
|
||||
If unique=True, candidates are drawn without replacement, else unique=False with replacement.
|
||||
|
||||
Args:
|
||||
num_true (int): The number of target classes in each training example.
|
||||
num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
|
||||
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
|
||||
unique (bool): Whether all sampled classes in a batch are unique.
|
||||
range_max (int): The number of possible classes, must be non-negative.
|
||||
seed (int): Used for random number generation, must be non-negative. If seed has a value of 0,
|
||||
the seed will be replaced with a randomly generated value. Default: 0.
|
||||
remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **true_classes** (Tensor) - A Tensor. The target classes with a Tensor shape of (batch_size, num_true).
|
||||
|
||||
Outputs:
|
||||
- **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes.
|
||||
Shape: (num_sampled, ).
|
||||
- **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each
|
||||
of true_classes. Shape: (batch_size, num_true).
|
||||
- **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of
|
||||
each of sampled_candidates. Shape: (num_sampled, ).
|
||||
|
||||
Raises:
|
||||
TypeError: If neither `num_true` nor `num_sampled` is an int.
|
||||
TypeError: If neither `unique` nor `remove_accidental_hits` is a bool.
|
||||
TypeError: If neither `range_max` nor `seed` is an int.
|
||||
TypeError: If `true_classes` is not a Tensor.
|
||||
Refer to :func:`mindspore.ops.uniform_candidate_sampler` for more detail.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> sampler = ops.UniformCandidateSampler(1, 3, False, 4)
|
||||
>>> sampler = ops.UniformCandidateSampler(1, 3, False, 4, 1)
|
||||
>>> output1, output2, output3 = sampler(Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32)))
|
||||
>>> print(output1)
|
||||
>>> print(output2)
|
||||
>>> print(output3)
|
||||
[1, 1, 3]
|
||||
[0, 0, 3]
|
||||
[[0.75], [0.75], [0.75], [0.75], [0.75]]
|
||||
[0.75, 0.75, 0.75]
|
||||
"""
|
||||
|
|
|
@ -18,6 +18,7 @@ import pytest
|
|||
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore.ops.functional import vmap
|
||||
|
@ -42,6 +43,11 @@ def uniform_candidate_sampler(x, num_true, num_sampled, unique, range_max):
|
|||
return out1.shape, out2.shape, out3.shape
|
||||
|
||||
|
||||
def uniform_candidate_sampler_functional(x, num_true, num_sample, unique, range_max):
|
||||
out1, out2, out3 = F.uniform_candidate_sampler(Tensor(x.astype(np.int32)), num_true, num_sample, unique, range_max)
|
||||
return out1.shape, out2.shape, out3.shape
|
||||
|
||||
|
||||
def uniform_candidate_sampler_int64(x, num_true, num_sampled, unique, range_max):
|
||||
uniform_candidate_sampler_net = UniformCandidateSamplerNet(num_true,
|
||||
num_sampled,
|
||||
|
@ -210,7 +216,7 @@ def test_uniform_candidate_sampler_large_random_int64_input():
|
|||
Description: The input data is random large with type int64 for UniformCandidateSampler
|
||||
Expectation: The value and shape of output are the expected values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
ms1, ms2, ms3 = uniform_candidate_sampler_int64(np.arange(2142).reshape(34, 63),
|
||||
63, 10, False, 12)
|
||||
expected_1 = (10,)
|
||||
|
@ -373,3 +379,65 @@ def test_uniform_candidate_sampler_vmap2_unique_1_true():
|
|||
np.testing.assert_array_equal(ms1, expected_1)
|
||||
np.testing.assert_array_equal(ms2, expected_2)
|
||||
np.testing.assert_array_equal(ms3, expected_3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_uniform_candidate_sampler_functional_unique_1_true():
|
||||
"""
|
||||
Feature: Functional interface of UniformCandidateSampler CPU TEST.
|
||||
Description: The unique is true for uniform_candidate_sampler
|
||||
Expectation: The shape of output are the expected values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
ms1, ms2, ms3 = uniform_candidate_sampler_functional(np.array([[1], [3], [4], [6], [3]]),
|
||||
1, 3, True, 4)
|
||||
expected_1 = (3,)
|
||||
expected_2 = (5, 1)
|
||||
expected_3 = (3,)
|
||||
np.testing.assert_array_equal(ms1, expected_1)
|
||||
np.testing.assert_array_equal(ms2, expected_2)
|
||||
np.testing.assert_array_equal(ms3, expected_3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_uniform_candidate_sampler_functional_not_unique_2_true():
|
||||
"""
|
||||
Feature: Functional interface of UniformCandidateSampler CPU TEST.
|
||||
Description: The unique is false and num_true is 2 for uniform_candidate_sampler
|
||||
Expectation: The value and shape of output are the expected values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
ms1, ms2, ms3 = uniform_candidate_sampler_functional(np.array([[1, 2], [3, 2],
|
||||
[4, 2], [6, 2],
|
||||
[3, 2]]),
|
||||
2, 3, False, 4)
|
||||
expected_1 = (3,)
|
||||
expected_2 = (5, 2)
|
||||
expected_3 = (3,)
|
||||
np.testing.assert_array_equal(ms1, expected_1)
|
||||
np.testing.assert_array_equal(ms2, expected_2)
|
||||
np.testing.assert_array_equal(ms3, expected_3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_uniform_candidate_sampler_functional_large_random():
|
||||
"""
|
||||
Feature: Functional interface of UniformCandidateSampler CPU TEST.
|
||||
Description: The input data is random large with type int32 for uniform_candidate_sampler
|
||||
Expectation: The shape of output are the expected values.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
ms1, ms2, ms3 = uniform_candidate_sampler_functional(np.arange(2142).reshape(34, 63),
|
||||
63, 10, False, 12)
|
||||
expected_1 = (10,)
|
||||
expected_2 = (34, 63)
|
||||
expected_3 = (10,)
|
||||
np.testing.assert_array_equal(ms1, expected_1)
|
||||
np.testing.assert_array_equal(ms2, expected_2)
|
||||
np.testing.assert_array_equal(ms3, expected_3)
|
||||
|
|
Loading…
Reference in New Issue