forked from mindspore-Ecosystem/mindspore
!44521 [MS]solve the review comments for log_uniform_candidate_sampler primitive
Merge pull request !44521 from zhaizhiqiang/master
This commit is contained in:
commit
9b86178885
|
@ -93,6 +93,7 @@ mindspore.ops.function
|
|||
:template: classtemplate.rst
|
||||
|
||||
mindspore.ops.grid_sample
|
||||
mindspore.ops.log_uniform_candidate_sampler
|
||||
mindspore.ops.uniform_candidate_sampler
|
||||
|
||||
距离函数
|
||||
|
@ -105,17 +106,6 @@ mindspore.ops.function
|
|||
|
||||
mindspore.ops.cdist
|
||||
|
||||
|
||||
采样函数
|
||||
^^^^^^^^^^
|
||||
|
||||
.. mscnplatformautosummary::
|
||||
:toctree: ops
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.ops.log_uniform_candidate_sampler
|
||||
|
||||
数学运算函数
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
|
@ -94,6 +94,7 @@ Sampling Functions
|
|||
:template: classtemplate.rst
|
||||
|
||||
mindspore.ops.grid_sample
|
||||
mindspore.ops.log_uniform_candidate_sampler
|
||||
mindspore.ops.uniform_candidate_sampler
|
||||
|
||||
Distance Functions
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include <random>
|
||||
#include "securec/include/securec.h"
|
||||
#include "mindspore/ccsrc/plugin/device/cpu/kernel/random_util.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -69,7 +70,10 @@ double Uint64ToDouble(uint32_t x0, uint32_t x1) {
|
|||
float Uint32ToFloat(uint32_t x) {
|
||||
uint32_t val = (127 << 23) | (x & 0x7fffffu);
|
||||
float f;
|
||||
memcpy(&f, &val, sizeof(val));
|
||||
auto ret = memcpy_s(&f, sizeof(f), &val, sizeof(val));
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Uint32ToFloat failed, memcpy_s errorno: " << ret;
|
||||
}
|
||||
return f - 1.0f;
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <set>
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "mindspore/core/abstract/ops/op_infer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -30,47 +31,38 @@ void LogUniformCandidateSampler::Init(int64_t num_true, int64_t num_sampled, boo
|
|||
this->set_range_max(range_max);
|
||||
this->set_seed(seed);
|
||||
}
|
||||
namespace {
|
||||
abstract::TupleShapePtr LogUniformCandidateSamplerInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
int64_t num_sampled = GetValue<int64_t>(primitive->GetAttr(kNumSampled));
|
||||
auto sampled_candidate_shape = std::make_shared<abstract::Shape>(ShapeVector({num_sampled}));
|
||||
auto true_expected_shape = input_args[0]->BuildShape();
|
||||
|
||||
std::vector<abstract::BaseShapePtr> shape_tuple;
|
||||
(void)shape_tuple.emplace_back(sampled_candidate_shape);
|
||||
(void)shape_tuple.emplace_back(true_expected_shape);
|
||||
(void)shape_tuple.emplace_back(sampled_candidate_shape);
|
||||
return std::make_shared<abstract::TupleShape>(shape_tuple);
|
||||
}
|
||||
class LogUniformCandidateSamplerInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
int64_t num_sampled = GetValue<int64_t>(primitive->GetAttr(kNumSampled));
|
||||
auto sampled_candidate_shape = std::make_shared<abstract::Shape>(ShapeVector({num_sampled}));
|
||||
auto true_expected_shape = input_args[0]->BuildShape();
|
||||
|
||||
TuplePtr LogUniformCandidateSamplerInferType(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
// check input data type
|
||||
const std::set<TypePtr> valid_types = {kInt64};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("true_classes", input_args[0]->BuildType(), valid_types,
|
||||
primitive->name());
|
||||
std::vector<abstract::BaseShapePtr> shape_tuple;
|
||||
(void)shape_tuple.emplace_back(sampled_candidate_shape);
|
||||
(void)shape_tuple.emplace_back(true_expected_shape);
|
||||
(void)shape_tuple.emplace_back(sampled_candidate_shape);
|
||||
return std::make_shared<abstract::TupleShape>(shape_tuple);
|
||||
}
|
||||
|
||||
// return outputs data type
|
||||
auto sampled_candidate_type = std::make_shared<TensorType>(kInt64);
|
||||
auto true_expected_type = std::make_shared<TensorType>(kFloat32);
|
||||
auto sampled_expected = std::make_shared<TensorType>(kFloat32);
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{sampled_candidate_type, true_expected_type, sampled_expected});
|
||||
}
|
||||
} // namespace
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
// check input data type
|
||||
const std::set<TypePtr> valid_types = {kInt64};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("true_classes", input_args[0]->BuildType(), valid_types,
|
||||
primitive->name());
|
||||
|
||||
AbstractBasePtr LogUniformCandidateSamplerInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kInputNum = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name());
|
||||
auto type = LogUniformCandidateSamplerInferType(primitive, input_args);
|
||||
auto shape = LogUniformCandidateSamplerInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
// return outputs data type
|
||||
auto sampled_candidate_type = std::make_shared<TensorType>(kInt64);
|
||||
auto true_expected_type = std::make_shared<TensorType>(kFloat32);
|
||||
auto sampled_expected = std::make_shared<TensorType>(kFloat32);
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{sampled_candidate_type, true_expected_type, sampled_expected});
|
||||
}
|
||||
};
|
||||
|
||||
MIND_API_OPERATOR_IMPL(LogUniformCandidateSampler, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LogUniformCandidateSampler, prim::kPrimLogUniformCandidateSampler,
|
||||
LogUniformCandidateSamplerInfer, nullptr, true);
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(LogUniformCandidateSampler, prim::kPrimLogUniformCandidateSampler,
|
||||
LogUniformCandidateSamplerInfer, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -70,10 +70,6 @@ class MIND_API LogUniformCandidateSampler : public BaseOperator {
|
|||
|
||||
inline int64_t get_seed() { return GetValue<int64_t>(GetAttr(kSeed)); }
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr LogUniformCandidateSamplerInfer(const abstract::AnalysisEnginePtr &,
|
||||
const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Defines parameter operators with functional form."""
|
||||
|
||||
import numpy as np
|
||||
|
@ -72,8 +71,7 @@ def random_gamma(shape, alpha, seed=0, seed2=0):
|
|||
beta = Tensor(np.array([1.0]), alpha_type)
|
||||
alpha_shape = P.Shape()(alpha)
|
||||
beta_shape = P.Shape()(beta)
|
||||
broadcast_shape = get_broadcast_shape(alpha_shape, beta_shape, "random_gamma",
|
||||
arg_name1="alpha", arg_name2="beta")
|
||||
broadcast_shape = get_broadcast_shape(alpha_shape, beta_shape, "random_gamma", arg_name1="alpha", arg_name2="beta")
|
||||
broadcast_shape_t = tuple(broadcast_shape)
|
||||
broadcast_to = P.BroadcastTo(broadcast_shape_t)
|
||||
alpha_broadcast = broadcast_to(alpha)
|
||||
|
@ -123,8 +121,7 @@ def standard_laplace(shape, seed=0, seed2=0):
|
|||
>>> print(result)
|
||||
(4, 4)
|
||||
"""
|
||||
standard_laplace_op = _get_cache_prim(
|
||||
P.StandardLaplace)(seed=seed, seed2=seed2)
|
||||
standard_laplace_op = _get_cache_prim(P.StandardLaplace)(seed=seed, seed2=seed2)
|
||||
return standard_laplace_op(shape)
|
||||
|
||||
|
||||
|
@ -223,13 +220,11 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
|
|||
(3, 2, 2)
|
||||
"""
|
||||
if not isinstance(minval, Tensor) or not isinstance(maxval, Tensor):
|
||||
raise TypeError(
|
||||
f"For functional operator[uniform], the input[minval] and input[maxval] must be a Tensor.")
|
||||
raise TypeError(f"For functional operator[uniform], the input[minval] and input[maxval] must be a Tensor.")
|
||||
|
||||
minval_dtype = F.dtype(minval)
|
||||
maxval_dtype = F.dtype(maxval)
|
||||
const_utils.check_type_valid(
|
||||
dtype, [mstype.int32, mstype.float32], 'uniform')
|
||||
const_utils.check_type_valid(dtype, [mstype.int32, mstype.float32], 'uniform')
|
||||
const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform")
|
||||
const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform")
|
||||
seed1, seed2 = _get_seed(seed, "uniform")
|
||||
|
@ -277,12 +272,16 @@ def standard_normal(shape, seed=0, seed2=0):
|
|||
>>> print(result)
|
||||
(4, 4)
|
||||
"""
|
||||
standard_normal_op = _get_cache_prim(
|
||||
P.StandardNormal)(seed=seed, seed2=seed2)
|
||||
standard_normal_op = _get_cache_prim(P.StandardNormal)(seed=seed, seed2=seed2)
|
||||
return standard_normal_op(shape)
|
||||
|
||||
|
||||
def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0,
|
||||
def uniform_candidate_sampler(true_classes,
|
||||
num_true,
|
||||
num_sampled,
|
||||
unique,
|
||||
range_max,
|
||||
seed=0,
|
||||
remove_accidental_hits=False):
|
||||
r"""
|
||||
Uniform candidate sampler.
|
||||
|
@ -328,10 +327,13 @@ def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range
|
|||
>>> print(output3.shape)
|
||||
(3,)
|
||||
"""
|
||||
sampler_op = _get_cache_prim(P.UniformCandidateSampler)(num_true, num_sampled, unique, range_max, seed=seed,
|
||||
sampler_op = _get_cache_prim(P.UniformCandidateSampler)(num_true,
|
||||
num_sampled,
|
||||
unique,
|
||||
range_max,
|
||||
seed=seed,
|
||||
remove_accidental_hits=remove_accidental_hits)
|
||||
sampled_candidates, true_expected_count, sampled_expected_count = sampler_op(
|
||||
true_classes)
|
||||
sampled_candidates, true_expected_count, sampled_expected_count = sampler_op(true_classes)
|
||||
return sampled_candidates, true_expected_count, sampled_expected_count
|
||||
|
||||
|
||||
|
@ -425,15 +427,14 @@ def shuffle(x, seed=None):
|
|||
return output
|
||||
|
||||
|
||||
def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
|
||||
range_max, seed=0):
|
||||
def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=0):
|
||||
r"""
|
||||
Generates random labels with a log-uniform distribution for sampled_candidates.
|
||||
|
||||
Randomly samples a tensor of sampled classes from the range of integers [0, range_max).
|
||||
|
||||
Args:
|
||||
true_classes (Tensor) - The target classes. With data type of int64 and
|
||||
true_classes (Tensor): The target classes. With data type of int64 and
|
||||
shape :math:`(batch\_size, num\_true)` .
|
||||
num_true (int): The number of target classes per training example. Default: 1.
|
||||
num_sampled (int): The number of classes to randomly sample. Default: 5.
|
||||
|
@ -454,13 +455,15 @@ def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
|
|||
Raises:
|
||||
TypeError: If neither `num_true` nor `num_sampled` is an int.
|
||||
TypeError: If `unique` is not a bool.
|
||||
TypeError: If neither `range_max` nor `seed` is an int.
|
||||
TypeError: If `true_classes` is not a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> output1, output2, output3 = F.log_uniform_candidate_sampler(
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>> output1, output2, output3 = ops.log_uniform_candidate_sampler(
|
||||
... Tensor(np.array([[1, 7], [0, 4], [3, 3]])), 2, 5, True, 5)
|
||||
>>> print(output1, output2, output3)
|
||||
[3 2 0 4 1]
|
||||
|
@ -471,8 +474,7 @@ def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
|
|||
|
||||
"""
|
||||
|
||||
sampler = _get_cache_prim(P.LogUniformCandidateSampler)(
|
||||
num_true, num_sampled, unique, range_max, seed)
|
||||
sampler = _get_cache_prim(P.LogUniformCandidateSampler)(num_true, num_sampled, unique, range_max, seed)
|
||||
return sampler(true_classes)
|
||||
|
||||
|
||||
|
@ -516,22 +518,13 @@ def choice_with_mask(input_x, count=256, seed=0, seed2=0):
|
|||
>>> print(result)
|
||||
(256,)
|
||||
"""
|
||||
choice_with_mask_ = _get_cache_prim(RandomChoiceWithMask)(
|
||||
count=count, seed=seed, seed2=seed2)
|
||||
choice_with_mask_ = _get_cache_prim(RandomChoiceWithMask)(count=count, seed=seed, seed2=seed2)
|
||||
output = choice_with_mask_(input_x)
|
||||
return output
|
||||
|
||||
|
||||
__all__ = [
|
||||
'standard_laplace',
|
||||
'random_categorical',
|
||||
'uniform',
|
||||
'standard_normal',
|
||||
'random_gamma',
|
||||
'uniform_candidate_sampler',
|
||||
'random_poisson',
|
||||
'log_uniform_candidate_sampler',
|
||||
'shuffle',
|
||||
'choice_with_mask'
|
||||
'standard_laplace', 'random_categorical', 'uniform', 'standard_normal', 'random_gamma',
|
||||
'uniform_candidate_sampler', 'random_poisson', 'log_uniform_candidate_sampler', 'shuffle', 'choice_with_mask'
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
Loading…
Reference in New Issue