!44521 [MS]solve the review comments for log_uniform_candidate_sampler primitive

Merge pull request !44521 from zhaizhiqiang/master
This commit is contained in:
i-robot 2022-10-26 02:01:22 +00:00 committed by Gitee
commit 9b86178885
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 62 additions and 86 deletions

View File

@ -93,6 +93,7 @@ mindspore.ops.function
:template: classtemplate.rst
mindspore.ops.grid_sample
mindspore.ops.log_uniform_candidate_sampler
mindspore.ops.uniform_candidate_sampler
距离函数
@ -105,17 +106,6 @@ mindspore.ops.function
mindspore.ops.cdist
采样函数
^^^^^^^^^^
.. mscnplatformautosummary::
:toctree: ops
:nosignatures:
:template: classtemplate.rst
mindspore.ops.log_uniform_candidate_sampler
数学运算函数
^^^^^^^^^^^^^^^^^

View File

@ -94,6 +94,7 @@ Sampling Functions
:template: classtemplate.rst
mindspore.ops.grid_sample
mindspore.ops.log_uniform_candidate_sampler
mindspore.ops.uniform_candidate_sampler
Distance Functions

View File

@ -17,6 +17,7 @@
#include <random>
#include "securec/include/securec.h"
#include "mindspore/ccsrc/plugin/device/cpu/kernel/random_util.h"
#include "include/common/utils/utils.h"
namespace mindspore {
namespace kernel {
@ -69,7 +70,10 @@ double Uint64ToDouble(uint32_t x0, uint32_t x1) {
float Uint32ToFloat(uint32_t x) {
uint32_t val = (127 << 23) | (x & 0x7fffffu);
float f;
memcpy(&f, &val, sizeof(val));
auto ret = memcpy_s(&f, sizeof(f), &val, sizeof(val));
if (ret != 0) {
MS_LOG(EXCEPTION) << "Uint32ToFloat failed, memcpy_s errorno: " << ret;
}
return f - 1.0f;
}

View File

@ -19,6 +19,7 @@
#include <set>
#include "utils/check_convert_utils.h"
#include "mindapi/src/helper.h"
#include "mindspore/core/abstract/ops/op_infer.h"
namespace mindspore {
namespace ops {
@ -30,47 +31,38 @@ void LogUniformCandidateSampler::Init(int64_t num_true, int64_t num_sampled, boo
this->set_range_max(range_max);
this->set_seed(seed);
}
namespace {
abstract::TupleShapePtr LogUniformCandidateSamplerInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
int64_t num_sampled = GetValue<int64_t>(primitive->GetAttr(kNumSampled));
auto sampled_candidate_shape = std::make_shared<abstract::Shape>(ShapeVector({num_sampled}));
auto true_expected_shape = input_args[0]->BuildShape();
std::vector<abstract::BaseShapePtr> shape_tuple;
(void)shape_tuple.emplace_back(sampled_candidate_shape);
(void)shape_tuple.emplace_back(true_expected_shape);
(void)shape_tuple.emplace_back(sampled_candidate_shape);
return std::make_shared<abstract::TupleShape>(shape_tuple);
}
class LogUniformCandidateSamplerInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
int64_t num_sampled = GetValue<int64_t>(primitive->GetAttr(kNumSampled));
auto sampled_candidate_shape = std::make_shared<abstract::Shape>(ShapeVector({num_sampled}));
auto true_expected_shape = input_args[0]->BuildShape();
TuplePtr LogUniformCandidateSamplerInferType(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
// check input data type
const std::set<TypePtr> valid_types = {kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("true_classes", input_args[0]->BuildType(), valid_types,
primitive->name());
std::vector<abstract::BaseShapePtr> shape_tuple;
(void)shape_tuple.emplace_back(sampled_candidate_shape);
(void)shape_tuple.emplace_back(true_expected_shape);
(void)shape_tuple.emplace_back(sampled_candidate_shape);
return std::make_shared<abstract::TupleShape>(shape_tuple);
}
// return outputs data type
auto sampled_candidate_type = std::make_shared<TensorType>(kInt64);
auto true_expected_type = std::make_shared<TensorType>(kFloat32);
auto sampled_expected = std::make_shared<TensorType>(kFloat32);
return std::make_shared<Tuple>(std::vector<TypePtr>{sampled_candidate_type, true_expected_type, sampled_expected});
}
} // namespace
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
// check input data type
const std::set<TypePtr> valid_types = {kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("true_classes", input_args[0]->BuildType(), valid_types,
primitive->name());
AbstractBasePtr LogUniformCandidateSamplerInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputNum = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name());
auto type = LogUniformCandidateSamplerInferType(primitive, input_args);
auto shape = LogUniformCandidateSamplerInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
// return outputs data type
auto sampled_candidate_type = std::make_shared<TensorType>(kInt64);
auto true_expected_type = std::make_shared<TensorType>(kFloat32);
auto sampled_expected = std::make_shared<TensorType>(kFloat32);
return std::make_shared<Tuple>(std::vector<TypePtr>{sampled_candidate_type, true_expected_type, sampled_expected});
}
};
MIND_API_OPERATOR_IMPL(LogUniformCandidateSampler, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(LogUniformCandidateSampler, prim::kPrimLogUniformCandidateSampler,
LogUniformCandidateSamplerInfer, nullptr, true);
REGISTER_PRIMITIVE_OP_INFER_IMPL(LogUniformCandidateSampler, prim::kPrimLogUniformCandidateSampler,
LogUniformCandidateSamplerInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -70,10 +70,6 @@ class MIND_API LogUniformCandidateSampler : public BaseOperator {
inline int64_t get_seed() { return GetValue<int64_t>(GetAttr(kSeed)); }
};
abstract::AbstractBasePtr LogUniformCandidateSamplerInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Defines parameter operators with functional form."""
import numpy as np
@ -72,8 +71,7 @@ def random_gamma(shape, alpha, seed=0, seed2=0):
beta = Tensor(np.array([1.0]), alpha_type)
alpha_shape = P.Shape()(alpha)
beta_shape = P.Shape()(beta)
broadcast_shape = get_broadcast_shape(alpha_shape, beta_shape, "random_gamma",
arg_name1="alpha", arg_name2="beta")
broadcast_shape = get_broadcast_shape(alpha_shape, beta_shape, "random_gamma", arg_name1="alpha", arg_name2="beta")
broadcast_shape_t = tuple(broadcast_shape)
broadcast_to = P.BroadcastTo(broadcast_shape_t)
alpha_broadcast = broadcast_to(alpha)
@ -123,8 +121,7 @@ def standard_laplace(shape, seed=0, seed2=0):
>>> print(result)
(4, 4)
"""
standard_laplace_op = _get_cache_prim(
P.StandardLaplace)(seed=seed, seed2=seed2)
standard_laplace_op = _get_cache_prim(P.StandardLaplace)(seed=seed, seed2=seed2)
return standard_laplace_op(shape)
@ -223,13 +220,11 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
(3, 2, 2)
"""
if not isinstance(minval, Tensor) or not isinstance(maxval, Tensor):
raise TypeError(
f"For functional operator[uniform], the input[minval] and input[maxval] must be a Tensor.")
raise TypeError(f"For functional operator[uniform], the input[minval] and input[maxval] must be a Tensor.")
minval_dtype = F.dtype(minval)
maxval_dtype = F.dtype(maxval)
const_utils.check_type_valid(
dtype, [mstype.int32, mstype.float32], 'uniform')
const_utils.check_type_valid(dtype, [mstype.int32, mstype.float32], 'uniform')
const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform")
const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform")
seed1, seed2 = _get_seed(seed, "uniform")
@ -277,12 +272,16 @@ def standard_normal(shape, seed=0, seed2=0):
>>> print(result)
(4, 4)
"""
standard_normal_op = _get_cache_prim(
P.StandardNormal)(seed=seed, seed2=seed2)
standard_normal_op = _get_cache_prim(P.StandardNormal)(seed=seed, seed2=seed2)
return standard_normal_op(shape)
def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0,
def uniform_candidate_sampler(true_classes,
num_true,
num_sampled,
unique,
range_max,
seed=0,
remove_accidental_hits=False):
r"""
Uniform candidate sampler.
@ -328,10 +327,13 @@ def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range
>>> print(output3.shape)
(3,)
"""
sampler_op = _get_cache_prim(P.UniformCandidateSampler)(num_true, num_sampled, unique, range_max, seed=seed,
sampler_op = _get_cache_prim(P.UniformCandidateSampler)(num_true,
num_sampled,
unique,
range_max,
seed=seed,
remove_accidental_hits=remove_accidental_hits)
sampled_candidates, true_expected_count, sampled_expected_count = sampler_op(
true_classes)
sampled_candidates, true_expected_count, sampled_expected_count = sampler_op(true_classes)
return sampled_candidates, true_expected_count, sampled_expected_count
@ -425,15 +427,14 @@ def shuffle(x, seed=None):
return output
def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
range_max, seed=0):
def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0):
r"""
Generates random labels with a log-uniform distribution for sampled_candidates.
Randomly samples a tensor of sampled classes from the range of integers [0, range_max).
Args:
true_classes (Tensor) - The target classes. With data type of int64 and
true_classes (Tensor): The target classes. With data type of int64 and
shape :math:`(batch\_size, num\_true)` .
num_true (int): The number of target classes per training example. Default: 1.
num_sampled (int): The number of classes to randomly sample. Default: 5.
@ -454,13 +455,15 @@ def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
Raises:
TypeError: If neither `num_true` nor `num_sampled` is an int.
TypeError: If `unique` is not a bool.
TypeError: If neither `range_max` nor `seed` is an int.
TypeError: If `true_classes` is not a Tensor.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> from mindspore.ops import functional as F
>>> output1, output2, output3 = F.log_uniform_candidate_sampler(
>>> from mindspore import Tensor, ops
>>> output1, output2, output3 = ops.log_uniform_candidate_sampler(
... Tensor(np.array([[1, 7], [0, 4], [3, 3]])), 2, 5, True, 5)
>>> print(output1, output2, output3)
[3 2 0 4 1]
@ -471,8 +474,7 @@ def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
"""
sampler = _get_cache_prim(P.LogUniformCandidateSampler)(
num_true, num_sampled, unique, range_max, seed)
sampler = _get_cache_prim(P.LogUniformCandidateSampler)(num_true, num_sampled, unique, range_max, seed)
return sampler(true_classes)
@ -516,22 +518,13 @@ def choice_with_mask(input_x, count=256, seed=0, seed2=0):
>>> print(result)
(256,)
"""
choice_with_mask_ = _get_cache_prim(RandomChoiceWithMask)(
count=count, seed=seed, seed2=seed2)
choice_with_mask_ = _get_cache_prim(RandomChoiceWithMask)(count=count, seed=seed, seed2=seed2)
output = choice_with_mask_(input_x)
return output
__all__ = [
'standard_laplace',
'random_categorical',
'uniform',
'standard_normal',
'random_gamma',
'uniform_candidate_sampler',
'random_poisson',
'log_uniform_candidate_sampler',
'shuffle',
'choice_with_mask'
'standard_laplace', 'random_categorical', 'uniform', 'standard_normal', 'random_gamma',
'uniform_candidate_sampler', 'random_poisson', 'log_uniform_candidate_sampler', 'shuffle', 'choice_with_mask'
]
__all__.sort()