!44753 [feat] [assistant] [I4XJID] Add MultinomialWithReplacement

Merge pull request !44753 from 桂宁馨/MultinomialWithReplacement
This commit is contained in:
i-robot 2022-11-12 06:48:27 +00:00 committed by Gitee
commit 60c1dde35b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 671 additions and 0 deletions

View File

@ -0,0 +1,294 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/multinomial_with_replacement_cpu_kernel.h"
#include <Eigen/Dense>
#include <algorithm>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include <cfloat>
#include <cmath>
#include <iostream>
#include <functional>
#include <random>
#include "mindspore/core/ops/multinomial_with_replacement.h"
#include "kernel/common_utils.h"
#include "utils/ms_utils.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
namespace mindspore {
namespace kernel {
namespace {
const size_t kMultinomialWithReplacementInputsNum = 3;
const size_t kMultinomialWithReplacementOutputsNum = 1;
} // namespace
uint64_t MultinomialWithReplacementCpuKernelMod::New64() {
std::random_device device("/dev/urandom");
static std::mt19937_64 rng = std::mt19937_64(device());
return (rng)();
}
void MultinomialWithReplacementCpuKernelMod::InitMSPhiloxRandom(int64_t seed_, int64_t offset_) {
if (seed_ == 0 && offset_ == 0) {
seed_ = New64();
offset_ = New64();
}
generator_ = random::MSPhiloxRandom(seed_, offset_);
}
float MultinomialWithReplacementCpuKernelMod::RandFloat() {
uint32_t x = GenerateSingle();
const uint32_t man = x & 0x7fffffu; // 23 bit mantissa
const uint32_t exp = static_cast<uint32_t>(127);
const uint32_t val = (exp << 23) | man;
float result;
memcpy_s(&result, sizeof(result), &val, sizeof(val));
return result - 1.0f;
}
uint32_t MultinomialWithReplacementCpuKernelMod::GenerateSingle() {
if (used_result_index_ == random::MSPhiloxRandom::kResultElementCount) {
unused_results_ = generator_();
used_result_index_ = 0;
}
return unused_results_[used_result_index_++];
}
bool MultinomialWithReplacementCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
auto op = std::dynamic_pointer_cast<ops::MultinomialWithReplacement>(base_operator);
kernel_name_ = op->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
kernel_ptr_ = std::make_shared<ops::MultinomialWithReplacement>(base_operator->GetPrim());
if (!is_match) {
MS_LOG(ERROR) << "MultinomialWithReplacement does not support this kernel data type: " << kernel_attr;
return false;
}
numsamples_ = op->get_numsamples();
replacement_ = op->get_replacement();
x_shape_ = inputs[0]->GetShapeVector();
kernel_func_ = func_list_[index].second;
return true;
}
int MultinomialWithReplacementCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
int ret = KRET_OK;
if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs)) != 0) {
return ret;
}
std::vector<int64_t> input_shape = inputs.at(kIndex0)->GetShapeVector();
std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
return ret;
}
template <typename T>
bool MultinomialWithReplacementCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kMultinomialWithReplacementInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kMultinomialWithReplacementOutputsNum, kernel_name_);
if (numsamples_ <= 0) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', 'numsamples' should be a nonnegative number, but got "
<< numsamples_ << ".";
}
auto x = reinterpret_cast<T *>(inputs[0]->addr);
auto seed_ = reinterpret_cast<int64_t>(inputs[1]->addr);
auto offset_ = reinterpret_cast<int64_t>(inputs[2]->addr);
InitMSPhiloxRandom(seed_, offset_);
int64_t num_row_ = 1;
size_t num_shape = 2;
if (x_shape_.size() == num_shape) {
num_row_ = x_shape_[0];
}
int64_t num_col_ = x_shape_[x_shape_.size() - 1];
for (int i = 0; i < num_row_; i++) {
double sum = 0;
auto row_start = x + i * num_col_;
for (int64_t j = 0; j < num_col_; ++j) {
if (static_cast<double>(*(row_start + j)) < 0) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
<< "' , each element of 'x' must be equal or greater than 0. ";
}
sum += static_cast<double>(*(row_start + j));
}
if (sum <= 0) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "' , the sum of each row of 'x' must be greater than 0. ";
}
}
int64_t output_size = num_row_ * numsamples_;
std::vector<T> RandomData(output_size);
for (int64_t i = 0; i < output_size; i++) {
RandomData[i] = static_cast<T>(RandFloat());
}
auto y = reinterpret_cast<int64_t *>(outputs[0]->addr);
for (int64_t i = 0; i < num_row_; i++) {
if (replacement_ == true) {
auto out = y + i * numsamples_;
auto in = x + i * num_col_;
out = TrueCompute<T>(in, out, RandomData.data(), i, num_col_);
} else {
auto out = y + i * numsamples_;
auto in = x + i * num_col_;
out = FalseCompute<T>(in, out, RandomData.data(), i, num_col_);
}
}
return true;
}
template <typename T>
int64_t *MultinomialWithReplacementCpuKernelMod::TrueCompute(T *in, int64_t *out, T *RandomData, int64_t i,
int64_t num_col_) {
double *cumulative_distribution_function = new double[num_col_];
double running_total = 0;
auto random = RandomData + i * numsamples_;
for (int64_t j = 0; j < num_col_; ++j) {
*(cumulative_distribution_function + j) = static_cast<double>(*(in + j));
}
for (int64_t j = 0; j < num_col_; ++j) {
if (*(cumulative_distribution_function + j) != 0.0) {
running_total += *(cumulative_distribution_function + j);
*(cumulative_distribution_function + j) = running_total;
}
}
for (int64_t j = 0; j < numsamples_; j++) {
double rand = static_cast<double>(*(random + j));
double rr = rand * running_total;
auto rt = running_total;
double *temp = &rt;
for (int k = 0; k < num_col_; k++) {
if (*(cumulative_distribution_function + k) >= rr && *(cumulative_distribution_function + k) <= *temp) {
*temp = *(cumulative_distribution_function + k);
}
}
for (int k = 0; k < num_col_; k++) {
if (*temp == *(cumulative_distribution_function + k)) {
*out = static_cast<int64_t>(k);
}
}
out = out + 1;
}
return out;
}
template <typename T>
int64_t *MultinomialWithReplacementCpuKernelMod::FalseCompute(T *in, int64_t *out, T *RandomData, int64_t i,
int64_t num_col_) {
double *cumulative_distribution_function = new double[num_col_];
T *weight = new T[num_col_];
int64_t zero_num = 0;
int64_t *zero_data = new int64_t[num_col_];
double running_total = 0;
auto random = RandomData + i * numsamples_;
std::copy_n(in, num_col_, weight);
for (int64_t j = 0; j < num_col_; ++j) {
*(cumulative_distribution_function + j) = static_cast<double>(*(in + j));
}
for (int64_t j = 0; j < num_col_; ++j) {
if (*(cumulative_distribution_function + j) != 0.0) {
running_total += *(cumulative_distribution_function + j);
*(cumulative_distribution_function + j) = running_total;
} else {
*(zero_data + zero_num) = static_cast<int64_t>(j);
zero_num = zero_num + 1;
}
}
for (int j = 0; j < numsamples_; j++) {
double rand = static_cast<double>(*(random + j));
double rr = rand * running_total;
auto rt = running_total;
double *temp = &rt;
if (j < num_col_ - zero_num) {
for (int k = 0; k < num_col_; k++) {
if (*(cumulative_distribution_function + k) >= rr && *(cumulative_distribution_function + k) <= *temp) {
*temp = *(cumulative_distribution_function + k);
}
}
for (int k = 0; k < num_col_; k++) {
if (*temp == *(cumulative_distribution_function + k)) {
*out = static_cast<int64_t>(k);
}
}
int co = *out;
*(weight + co) = static_cast<T>(0.0);
running_total = 0.0;
for (int64_t t = 0; t < num_col_; t++) {
*(cumulative_distribution_function + t) = static_cast<double>(*(weight + t));
}
for (int64_t t = 0; t < num_col_; t++) {
if (*(cumulative_distribution_function + t) != 0.0) {
running_total += *(cumulative_distribution_function + t);
*(cumulative_distribution_function + t) = running_total;
}
}
out = out + 1;
} else {
*out = *(zero_data + j - num_col_ + zero_num);
out = out + 1;
}
}
return out;
}
std::vector<std::pair<KernelAttr, MultinomialWithReplacementCpuKernelMod::MultinomialWithReplacementFunc>>
MultinomialWithReplacementCpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&MultinomialWithReplacementCpuKernelMod::LaunchKernel<float16>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&MultinomialWithReplacementCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&MultinomialWithReplacementCpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> MultinomialWithReplacementCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MultinomialWithReplacementFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, MultinomialWithReplacement, MultinomialWithReplacementCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,89 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_MULTINOMIAL_WITH_REPLACEMENT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_MULTINOMIAL_WITH_REPLACEMENT_CPU_KERNEL_H_
#include <vector>
#include <map>
#include <string>
#include <cmath>
#include <random>
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/cpu/kernel/random_util.h"
namespace mindspore {
namespace kernel {
class MultinomialWithReplacementCpuKernelMod : public NativeCpuKernelMod {
public:
MultinomialWithReplacementCpuKernelMod() = default;
~MultinomialWithReplacementCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
std::vector<KernelAttr> GetOpSupport() override;
private:
bool CheckMultinomialWithReplacementShape();
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using MultinomialWithReplacementFunc =
std::function<bool(MultinomialWithReplacementCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
template <typename T>
int64_t *TrueCompute(T *in, int64_t *out, T *RandomData, int64_t i, int64_t num_col_);
template <typename T>
int64_t *FalseCompute(T *in, int64_t *out, T *RandomData, int64_t i, int64_t num_col_);
private:
random::MSPhiloxRandom generator_;
using ResType = random::Array<uint32_t, random::MSPhiloxRandom::kResultElementCount>;
ResType unused_results_;
size_t used_result_index_ = random::MSPhiloxRandom::kResultElementCount;
float RandFloat();
uint64_t New64();
void InitMSPhiloxRandom(int64_t seed, int64_t offset);
uint32_t GenerateSingle();
static std::vector<std::pair<KernelAttr, MultinomialWithReplacementFunc>> func_list_;
MultinomialWithReplacementFunc kernel_func_;
ShapeVector x_shape_;
std::vector<size_t> input_shape_;
int64_t numsamples_;
bool replacement_;
int64_t num_row_;
int64_t num_col_;
BaseOperatorPtr kernel_ptr_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_MULTINOMIAL_WITH_REPLACEMENT_CPU_KERNEL_H_

View File

@ -1572,6 +1572,7 @@ GVAR_DEF(PrimitivePtr, kPrimRandperm, std::make_shared<Primitive>("Randperm"));
GVAR_DEF(PrimitivePtr, kPrimUniformCandidateSampler, std::make_shared<Primitive>("UniformCandidateSampler"));
GVAR_DEF(PrimitivePtr, kPrimLogUniformCandidateSampler, std::make_shared<Primitive>("LogUniformCandidateSampler"));
GVAR_DEF(PrimitivePtr, kPrimMultinomial, std::make_shared<Primitive>("Multinomial"));
GVAR_DEF(PrimitivePtr, kPrimMultinomialWithReplacement, std::make_shared<Primitive>("MultinomialWithReplacement"));
GVAR_DEF(PrimitivePtr, kPrimRandomChoiceWithMask, std::make_shared<Primitive>("RandomChoiceWithMask"));
// RL Ops

View File

@ -0,0 +1,123 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/multinomial_with_replacement.h"
#include <algorithm>
#include <set>
#include <string>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/ms_context.h"
#include "abstract/ops/primitive_infer_map.h"
#include "abstract/param_validator.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
ShapeVector output_sizeList;
} // namespace
void MultinomialWithReplacement::Init(int64_t numsamples, bool replacement) {
this->set_numsamples(numsamples);
this->set_replacement(replacement);
}
void MultinomialWithReplacement::set_numsamples(int64_t numsamples) {
(void)this->AddAttr("numsamples", api::MakeValue(numsamples));
}
int64_t MultinomialWithReplacement::get_numsamples() const {
auto numsamples = this->GetAttr("numsamples");
MS_EXCEPTION_IF_NULL(numsamples);
return GetValue<int64_t>(numsamples);
}
void MultinomialWithReplacement::set_replacement(bool replacement) {
(void)this->AddAttr("replacement", api::MakeValue(replacement));
}
bool MultinomialWithReplacement::get_replacement() const {
auto replacement = this->GetAttr("replacement");
MS_EXCEPTION_IF_NULL(replacement);
return GetValue<bool>(replacement);
}
abstract::BaseShapePtr MultinomialWithReplacementInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
const int64_t x_rank_max = 2;
const int64_t x_rank_min = 1;
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
std::vector<int64_t> y_shape;
if (x_shape.size() > x_rank_max || x_shape.size() < x_rank_min) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', 'x' must have a rank of 1 or 2, but got rank "
<< x_shape.size() << ".";
}
auto numsamples_ptr = primitive->GetAttr("numsamples");
auto numsamples = GetValue<int64_t>(numsamples_ptr);
auto replacement_ptr = primitive->GetAttr("replacement");
auto replacement = GetValue<bool>(replacement_ptr);
if (x_shape.size() == x_rank_min) {
if (replacement == false && x_shape[0] < numsamples) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
<< "', the value of numsamples must equal or less than x_shape[-1], but got "
<< numsamples << ".";
}
y_shape.push_back(numsamples);
} else {
if (replacement == false && x_shape[1] < numsamples) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
<< "', the value of numsamples must equal or less than x_shape[-1], but got "
<< numsamples << ".";
}
y_shape.push_back(x_shape[0]);
y_shape.push_back(numsamples);
}
return std::make_shared<abstract::Shape>(y_shape);
}
TypePtr MultinomialWithReplacementInferType(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
auto x_dtype = input_args[0]->BuildType();
auto seed_dtype = input_args[1]->BuildType();
auto offset_dtype = input_args[2]->BuildType();
TypePtr y_type = {kInt64};
const std::set<TypePtr> valid_types_x = {kFloat16, kFloat32, kFloat64};
const std::set<TypePtr> valid_types_seed = {kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, valid_types_x, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("seed_dtype", seed_dtype, valid_types_seed, op_name);
CheckAndConvertUtils::CheckTensorTypeValid("offset_dtype", offset_dtype, valid_types_seed, op_name);
return y_type;
}
MIND_API_OPERATOR_IMPL(MultinomialWithReplacement, BaseOperator);
AbstractBasePtr MultinomialWithReplacementInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto types = MultinomialWithReplacementInferType(primitive, input_args);
auto shapes = MultinomialWithReplacementInferShape(primitive, input_args);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_PRIMITIVE_EVAL_IMPL(MultinomialWithReplacement, prim::kPrimMultinomialWithReplacement,
MultinomialWithReplacementInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MULTINOMIAL_WITH_REPLACEMENT_H_
#define MINDSPORE_CORE_OPS_MULTINOMIAL_WITH_REPLACEMENT_H_
#include <memory>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMultinomialWithReplacement = "MultinomialWithReplacement";
class MIND_API MultinomialWithReplacement : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MultinomialWithReplacement);
MultinomialWithReplacement() : BaseOperator(kNameMultinomialWithReplacement) {
InitIOName({"x", "seed", "offset"}, {"y"});
}
void Init(int64_t numsamples, bool replacement = false);
void set_numsamples(int64_t numsamples);
int64_t get_numsamples() const;
void set_replacement(bool replacement);
bool get_replacement() const;
};
AbstractBasePtr MultinomialWithReplacementInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimMultinomialWithReplacementPtr = std::shared_ptr<MultinomialWithReplacement>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MULTINOMIAL_WITH_REPLACEMENT_H_

View File

@ -0,0 +1,35 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MultinomialWithReplacement op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
multinomial_with_replacement_op_info = AiCPURegOp("MultinomialWithReplacement") \
.fusion_type("OPAQUE") \
.input(0, "x", "required") \
.input(1, "seed", "required") \
.input(2, "offset", "required") \
.output(0, "y", "required") \
.attr("numsamples", "int") \
.attr("replacement", "bool") \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(multinomial_with_replacement_op_info)
def _multinomial_with_replacement_aicpu():
"""MultinomialWithReplacement aicpu register"""
return

View File

@ -163,6 +163,50 @@ def random_categorical(logits, num_sample, seed=0, dtype=mstype.int64):
return random_categorical_(logits, num_sample, seed)
def multinomial_with_replacement(x, seed, offset, numsamples, replacement=False):
r"""
Returns a tensor where each row contains `numsamples` elements sampled from the multinomial distribution.
Note:
The rows of input do not need to sum to one (in which case we use the values as weights),
but must be non-negative, finite and have a non-zero sum.
Args:
x (Tensor): the input tensor containing the cumsum of probabilities, must be 1 or 2
dimensions. Must be one of the following types: float16, float32, float64.
seed (int): If seed is set to be -1, and offset is set to be 0, the random number
generator is seeded by a random seed. Otherwise, it is seeded by the given seed.
offset (int): To avoid seed collision.
numsamples (int): the number of samples to draw.
replacement (bool): Whether to draw with replacement or not. Defaults to false.
Returns:
Tensor with the same rows as `x`, each row has numsamples sampled indices.
Raises:
TypeError: If `x` is not a Tensor whose dtype is float16, float32, float64.
TypeError: If `numsamples` is not an int.
TypeError: If `replacement` is not a bool.
ValueError: If `x` rank is not 1 or 2.
ValueError: If the value of `numsamples` must larger than x_shape[-1], when `replacement` is false.
ValueError: If the sum of one row of `x` less than 0.
ValueError: If one of the element of each row of `x` less than 0.
ValueError: If `numsamples` equal or less than 0.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> x = Tensor([[0., 9., 4., 0.]], mstype.float32)
>>> output = multinomialwithreplacement(x, 2, 5, 2, True)
>>> print(output)
[[1 1]]
"""
multinomial_with_replacement_ = _get_cache_prim(P.MultinomialWithReplacement) \
(numsamples=numsamples, replacement=replacement)
return multinomial_with_replacement_(x, seed, offset)
def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
"""
Generates random numbers according to the Uniform random number distribution.

View File

@ -822,6 +822,35 @@ class Multinomial(Primitive):
Validator.check_type_name("dtype", dtype, valid_values, self.name)
class MultinomialWithReplacement(Primitive):
r"""
Returns a tensor where each row contains numsamples indices sampled from the multinomial distribution.
Note:
The rows of input do not need to sum to one (in which case we use the values as weights),
but must be non-negative, finite and have a non-zero sum.
Refer to :func:`mindspore.ops.multinomial_with_replacement` for more detail.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> x = Tensor([[0., 9., 4., 0.]], mstype.float32)
>>> multinomialwithreplacement = ops.MultinomialWithReplacement(numsamples=2,replacement=True)
>>> output = multinomialwithreplacement(x, 2, 5)
>>> print(output)
[[1 1]]
"""
@prim_attr_register
def __init__(self, numsamples, replacement=False):
"""Initialize MultinomialWithReplacement."""
Validator.check_non_negative_int(numsamples, "numsamples", self.name)
Validator.check_value_type("replacement", replacement, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'seed', 'offset'], outputs=['y'])
class UniformCandidateSampler(PrimitiveWithInfer):
r"""
Uniform candidate sampler.

View File

@ -77,6 +77,7 @@ from mindspore.ops.operations.array_ops import ScatterAddWithAxis
from mindspore.ops.operations.array_ops import ConcatOffsetV1
from mindspore.ops.operations.random_ops import NonDeterministicInts
from mindspore.ops.operations.random_ops import TruncatedNormal
from mindspore.ops.operations.random_ops import MultinomialWithReplacement
from mindspore.ops.operations.random_ops import ParameterizedTruncatedNormal
from mindspore.ops.operations.random_ops import LogNormalReverse
from mindspore.ops.operations.image_ops import NonMaxSuppressionWithOverlaps
@ -4190,6 +4191,12 @@ test_case_other_ops = [
'block': TruncatedNormal(dtype=mstype.float32, seed=1, seed2=1),
'desc_inputs': [Tensor(np.array([2, 2]), mstype.int32)],
'skip': ['backward']}),
('MultinomialWithReplacement', {
'block': MultinomialWithReplacement(numsamples=3, replacement=True),
'desc_inputs': [Tensor(np.array([4, 4, 5, 6]), mstype.float32),
Tensor(np.array([1]), mstype.int64),
Tensor(np.array([1]), mstype.int64)],
'skip': ['backward']}),
('ParameterizedTruncatedNormal', {
'block': ParameterizedTruncatedNormal(seed=1, seed2=2),
'desc_inputs': [Tensor(np.array([2, 3]), mstype.int32),