|
|
|
@ -15,16 +15,11 @@
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "plugin/device/cpu/kernel/uniform_candidate_sampler_cpu_kernel.h"
|
|
|
|
|
#include <cmath>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <random>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <limits>
|
|
|
|
|
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <random>
|
|
|
|
|
#include <sstream>
|
|
|
|
|
#include "abstract/utils.h"
|
|
|
|
|
#include "mindspore/core/ops/uniform_candidate_sampler.h"
|
|
|
|
|
#include "mindspore/core/ops/core_ops.h"
|
|
|
|
@ -33,7 +28,7 @@ namespace kernel {
|
|
|
|
|
namespace {
|
|
|
|
|
template <typename S>
|
|
|
|
|
S Probability(int64_t range_max) {
|
|
|
|
|
if (range_max == 0) {
|
|
|
|
|
if (range_max <= 0) {
|
|
|
|
|
return S(0);
|
|
|
|
|
}
|
|
|
|
|
return static_cast<S>(1.0f / range_max);
|
|
|
|
@ -41,6 +36,7 @@ S Probability(int64_t range_max) {
|
|
|
|
|
|
|
|
|
|
template <typename S>
|
|
|
|
|
S ApproximateExpectedCount(S p, int64_t sampled_size, int64_t counter) {
|
|
|
|
|
// p >= 0 && p < 1.0
|
|
|
|
|
if (sampled_size == counter) {
|
|
|
|
|
return p * sampled_size;
|
|
|
|
|
}
|
|
|
|
@ -53,47 +49,44 @@ const size_t kInputRank = 2;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
int64_t UniformCandidateSamplerCpuKernelMod::Sampling(T *sampled_candidates_, const size_t length) {
|
|
|
|
|
int64_t counter = 0;
|
|
|
|
|
if (length != static_cast<size_t>(num_sampled_ * sizeof(T))) {
|
|
|
|
|
int64_t UniformCandidateSamplerCpuKernelMod::Sampling(T *sampled_candidates_, unsigned int seed, const size_t length) {
|
|
|
|
|
size_t target_length = LongToSize(num_sampled_) * sizeof(T);
|
|
|
|
|
if (length != target_length) {
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
// pick between [0, range_max_-1]
|
|
|
|
|
T range = static_cast<T>(range_max_);
|
|
|
|
|
T range{0};
|
|
|
|
|
if constexpr (sizeof(T) == sizeof(int64_t)) {
|
|
|
|
|
range = range_max_;
|
|
|
|
|
} else if constexpr (sizeof(T) == sizeof(int32_t)) {
|
|
|
|
|
range = LongToInt(range_max_); // range_max_ less than the max value of ‘int32_t’ number
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unknown type for sampling.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::mt19937 random_generator(seed);
|
|
|
|
|
std::uniform_int_distribution<T> distribution(0, range - 1);
|
|
|
|
|
if (!unique_) {
|
|
|
|
|
auto task = [this, &sampled_candidates_, &distribution](size_t start, size_t end) {
|
|
|
|
|
for (size_t i = start; i < end; i++) {
|
|
|
|
|
sampled_candidates_[i] = distribution(generator_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
ParallelLaunchAutoSearch(task, num_sampled_, this, ¶llel_search_info_, pool_);
|
|
|
|
|
counter = num_sampled_;
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
for (int64_t i = 0; i < num_sampled_; i++) {
|
|
|
|
|
oss << sampled_candidates_[i] << ", ";
|
|
|
|
|
sampled_candidates_[i] = distribution(random_generator);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "For UniformCandidateSampler, sampled_candidates: " << oss.str();
|
|
|
|
|
return counter;
|
|
|
|
|
return num_sampled_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t picked = 0;
|
|
|
|
|
int64_t counter = 0;
|
|
|
|
|
std::unordered_set<T> set_container;
|
|
|
|
|
while (picked < num_sampled_) {
|
|
|
|
|
T tmp = distribution(generator_);
|
|
|
|
|
T sample = distribution(random_generator);
|
|
|
|
|
counter++;
|
|
|
|
|
if ((set_container.find(tmp) == set_container.end()) &&
|
|
|
|
|
((!remove_accidental_hits_) || set_input_.find(tmp) == set_input_.end())) {
|
|
|
|
|
set_container.insert(tmp);
|
|
|
|
|
sampled_candidates_[picked] = tmp;
|
|
|
|
|
if ((set_container.find(sample) == set_container.end()) &&
|
|
|
|
|
((!remove_accidental_hits_) || set_input_.find(sample) == set_input_.end())) {
|
|
|
|
|
(void)set_container.insert(sample);
|
|
|
|
|
sampled_candidates_[picked] = sample;
|
|
|
|
|
picked++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::ostringstream oss;
|
|
|
|
|
for (int64_t i = 0; i < num_sampled_; i++) {
|
|
|
|
|
oss << sampled_candidates_[i] << ", ";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "For UniformCandidateSampler, sampled_candidates: " << oss.str();
|
|
|
|
|
|
|
|
|
|
return counter;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -115,10 +108,10 @@ void UniformCandidateSamplerCpuKernelMod::ExpectedLanuch(const int64_t counter,
|
|
|
|
|
sampled_expected_count[i] = value;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
ParallelLaunchAutoSearch(task2, num_sampled_, this, ¶llel_search_info_, pool_);
|
|
|
|
|
ParallelLaunchAutoSearch(task2, LongToSize(num_sampled_), this, ¶llel_search_info_, pool_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool UniformCandidateSamplerCpuKernelMod::CheckAttribute() {
|
|
|
|
|
void UniformCandidateSamplerCpuKernelMod::CheckAttribute() {
|
|
|
|
|
// check attrs
|
|
|
|
|
if (num_true_ <= 0 || num_sampled_ <= 0 || range_max_ <= 0) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "For 'UniformCandidateSampler', the parameters must be larger than 0, but got "
|
|
|
|
@ -131,10 +124,9 @@ bool UniformCandidateSamplerCpuKernelMod::CheckAttribute() {
|
|
|
|
|
<< "'range_max', but got 'num_sampled' = " << num_sampled_
|
|
|
|
|
<< ", 'range_max' = " << range_max_;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool UniformCandidateSamplerCpuKernelMod::CheckInputsAndOutputs(const std::vector<KernelTensorPtr> &inputs,
|
|
|
|
|
void UniformCandidateSamplerCpuKernelMod::CheckInputsAndOutputs(const std::vector<KernelTensorPtr> &inputs,
|
|
|
|
|
const std::vector<KernelTensorPtr> &outputs) {
|
|
|
|
|
if (inputs.empty() || outputs.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "For 'UniformCandidateSampler', inputs or outputs can not be empty.";
|
|
|
|
@ -146,12 +138,12 @@ bool UniformCandidateSamplerCpuKernelMod::CheckInputsAndOutputs(const std::vecto
|
|
|
|
|
<< ", outputs' size: " << outputs.size();
|
|
|
|
|
}
|
|
|
|
|
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
|
|
|
|
|
auto input_rank = static_cast<size_t>(batch_rank_ + kInputRank);
|
|
|
|
|
auto input_rank = LongToSize(batch_rank_) + kInputRank;
|
|
|
|
|
if (input_shape.size() != input_rank) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "For 'UniformCandidateSampler', the dimension of input 'true_classes' must be "
|
|
|
|
|
<< input_rank << ", but got " << input_shape.size();
|
|
|
|
|
}
|
|
|
|
|
auto kindex = static_cast<size_t>(batch_rank_ + kIndex1);
|
|
|
|
|
auto kindex = LongToSize(batch_rank_) + kIndex1;
|
|
|
|
|
if (input_shape[kindex] != num_true_) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "For 'UniformCandidateSampler', the input 'true_classes' must have 'num_true' columns, "
|
|
|
|
|
<< "but got 'true_classes': (" << input_shape[0] << ", " << input_shape[1] << ")"
|
|
|
|
@ -160,7 +152,7 @@ bool UniformCandidateSamplerCpuKernelMod::CheckInputsAndOutputs(const std::vecto
|
|
|
|
|
|
|
|
|
|
auto output_kIndex0_type = outputs.at(kIndex0)->GetDtype();
|
|
|
|
|
if (output_kIndex0_type == kNumberTypeInt32) {
|
|
|
|
|
if (range_max_ > static_cast<int64_t>(std::numeric_limits<int>::max())) {
|
|
|
|
|
if (range_max_ > std::numeric_limits<int>::max()) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', 'range_max' can not exceed the range of int32, but "
|
|
|
|
|
<< "got" << range_max_ << ". The input data type should be changed to int64.";
|
|
|
|
|
}
|
|
|
|
@ -168,11 +160,9 @@ bool UniformCandidateSamplerCpuKernelMod::CheckInputsAndOutputs(const std::vecto
|
|
|
|
|
|
|
|
|
|
if (std::any_of(input_shape.begin(), input_shape.end(), [](size_t i) { return i == 0; })) {
|
|
|
|
|
is_null_input_ = true;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (is_null_input_) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
|
|
|
|
|
|
|
|
|
|
batch_size_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
|
|
|
|
@ -180,47 +170,47 @@ bool UniformCandidateSamplerCpuKernelMod::CheckInputsAndOutputs(const std::vecto
|
|
|
|
|
if (batch_size_ == 0) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the shape of output 'sampled_candidates' can not be 0";
|
|
|
|
|
}
|
|
|
|
|
input_size_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>());
|
|
|
|
|
input_size_ = input_size_ / batch_size_;
|
|
|
|
|
input_size_ = LongToSize(std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>()));
|
|
|
|
|
input_size_ = input_size_ / LongToSize(batch_size_);
|
|
|
|
|
|
|
|
|
|
output_sizes_.emplace_back(num_sampled_);
|
|
|
|
|
output_sizes_.emplace_back(input_size_);
|
|
|
|
|
output_sizes_.emplace_back(num_sampled_);
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
(void)output_sizes_.emplace_back(num_sampled_);
|
|
|
|
|
(void)output_sizes_.emplace_back(input_size_);
|
|
|
|
|
(void)output_sizes_.emplace_back(num_sampled_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool UniformCandidateSamplerCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
|
|
|
|
const std::vector<KernelTensorPtr> &inputs,
|
|
|
|
|
const std::vector<KernelTensorPtr> &outputs) {
|
|
|
|
|
if (!base_operator) {
|
|
|
|
|
auto kernel_ptr = std::dynamic_pointer_cast<ops::UniformCandidateSampler>(base_operator);
|
|
|
|
|
if (!kernel_ptr) {
|
|
|
|
|
MS_LOG(ERROR) << "UniformCandiadataSampler ops is null.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
kernel_name_ = base_operator->name();
|
|
|
|
|
batch_rank_ = base_operator->get_batch_rank();
|
|
|
|
|
// getting attrs
|
|
|
|
|
if (kernel_name_ == prim::kPrimUniformCandidateSampler->name()) {
|
|
|
|
|
auto kernel_ptr_ = std::make_shared<ops::UniformCandidateSampler>(base_operator->GetPrim());
|
|
|
|
|
num_true_ = kernel_ptr_->get_num_true();
|
|
|
|
|
num_sampled_ = kernel_ptr_->get_num_sampled();
|
|
|
|
|
unique_ = kernel_ptr_->get_unique();
|
|
|
|
|
range_max_ = kernel_ptr_->get_range_max();
|
|
|
|
|
init_seed_ = kernel_ptr_->get_seed();
|
|
|
|
|
if (init_seed_ == 0) {
|
|
|
|
|
cur_seed_ = time(nullptr);
|
|
|
|
|
generator_.seed(cur_seed_);
|
|
|
|
|
} else {
|
|
|
|
|
generator_.seed(init_seed_);
|
|
|
|
|
}
|
|
|
|
|
remove_accidental_hits_ = kernel_ptr_->get_remove_accidental_hits();
|
|
|
|
|
} else {
|
|
|
|
|
kernel_name_ = kernel_ptr->name();
|
|
|
|
|
batch_rank_ = kernel_ptr->get_batch_rank();
|
|
|
|
|
|
|
|
|
|
if (kernel_name_ != prim::kPrimUniformCandidateSampler->name()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "For UniformCandidateSamplerCpuKernelMod, it's name must be UniformCandidateSampler, but got "
|
|
|
|
|
<< "invalid kernel name " << prim::kPrimUniformCandidateSampler->name();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
(void)CheckAttribute();
|
|
|
|
|
(void)CheckInputsAndOutputs(inputs, outputs);
|
|
|
|
|
// get attribute
|
|
|
|
|
num_true_ = kernel_ptr->get_num_true();
|
|
|
|
|
num_sampled_ = kernel_ptr->get_num_sampled();
|
|
|
|
|
unique_ = kernel_ptr->get_unique();
|
|
|
|
|
range_max_ = kernel_ptr->get_range_max();
|
|
|
|
|
int64_t seed_ = kernel_ptr->get_seed();
|
|
|
|
|
remove_accidental_hits_ = kernel_ptr->get_remove_accidental_hits();
|
|
|
|
|
|
|
|
|
|
if (seed_ < 0) {
|
|
|
|
|
MS_EXCEPTION(ValueError) << "For 'UniformCandidateSampler', the parameter 'seed' can not be less than 0, but got: "
|
|
|
|
|
<< seed_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
init_seed_ = LongToUint(seed_);
|
|
|
|
|
// check the attribute, inputs and outputs
|
|
|
|
|
CheckAttribute();
|
|
|
|
|
CheckInputsAndOutputs(inputs, outputs);
|
|
|
|
|
|
|
|
|
|
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
|
|
|
|
return false;
|
|
|
|
@ -233,7 +223,7 @@ int UniformCandidateSamplerCpuKernelMod::Resize(const BaseOperatorPtr &base_oper
|
|
|
|
|
const std::vector<KernelTensorPtr> &outputs,
|
|
|
|
|
const std::map<uint32_t, tensor::TensorPtr> &) {
|
|
|
|
|
int ret = KernelMod::Resize(base_operator, inputs, outputs);
|
|
|
|
|
if (ret != KRET_OK) {
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
|
|
|
|
@ -245,13 +235,13 @@ int UniformCandidateSamplerCpuKernelMod::Resize(const BaseOperatorPtr &base_oper
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
|
|
|
|
|
input_size_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>());
|
|
|
|
|
input_size_ = input_size_ / batch_size_;
|
|
|
|
|
input_size_ = LongToSize(std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>()));
|
|
|
|
|
input_size_ = input_size_ / LongToSize(batch_size_);
|
|
|
|
|
|
|
|
|
|
output_sizes_.emplace_back(num_sampled_);
|
|
|
|
|
output_sizes_.emplace_back(input_size_);
|
|
|
|
|
output_sizes_.emplace_back(num_sampled_);
|
|
|
|
|
return KRET_OK;
|
|
|
|
|
(void)output_sizes_.emplace_back(num_sampled_);
|
|
|
|
|
(void)output_sizes_.emplace_back(input_size_);
|
|
|
|
|
(void)output_sizes_.emplace_back(num_sampled_);
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename S>
|
|
|
|
@ -262,33 +252,38 @@ bool UniformCandidateSamplerCpuKernelMod::LaunchKernel(const std::vector<Address
|
|
|
|
|
MS_LOG(WARNING) << "For 'UniformCandidateSampler', the input 'true_classes' was empty.";
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
if (init_seed_ == 0 && cur_seed_ == 0) {
|
|
|
|
|
cur_seed_ = time(nullptr);
|
|
|
|
|
generator_.seed(cur_seed_);
|
|
|
|
|
} else if (init_seed_ != 0) {
|
|
|
|
|
generator_.seed(init_seed_);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "For UniformCandidateSampler, generator seed : init_seed_ = " << init_seed_
|
|
|
|
|
<< "cur_seed_ = " << cur_seed_;
|
|
|
|
|
(void)workspaces;
|
|
|
|
|
|
|
|
|
|
T *sampled_candidates = GetDeviceAddress<T>(outputs, kIndex0);
|
|
|
|
|
S *true_expected_count = GetDeviceAddress<S>(outputs, kIndex1);
|
|
|
|
|
S *sampled_expected_count = GetDeviceAddress<S>(outputs, kIndex2);
|
|
|
|
|
T *input = GetDeviceAddress<T>(inputs, kIndex0);
|
|
|
|
|
|
|
|
|
|
for (int64_t j = 0; j < batch_size_; ++j) {
|
|
|
|
|
unsigned int RNG_seed = 0;
|
|
|
|
|
std::random_device rd;
|
|
|
|
|
if (init_seed_ != 0) {
|
|
|
|
|
RNG_seed = init_seed_;
|
|
|
|
|
} else {
|
|
|
|
|
RNG_seed = rd();
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "For UniformCandidateSampler, generator seed : RNG_seed = " << RNG_seed;
|
|
|
|
|
|
|
|
|
|
if (remove_accidental_hits_) {
|
|
|
|
|
set_input_.clear(); // reset for each batch
|
|
|
|
|
for (size_t i = 0; i < input_size_; i++) {
|
|
|
|
|
set_input_.insert(static_cast<int64_t>(input[i]));
|
|
|
|
|
(void)set_input_.insert(input[i]);
|
|
|
|
|
}
|
|
|
|
|
if (num_sampled_ + static_cast<int64_t>(set_input_.size()) > range_max_) {
|
|
|
|
|
if (num_sampled_ + SizeToLong(set_input_.size()) > range_max_) {
|
|
|
|
|
MS_LOG(WARNING) << "For 'UniformCandidateSampler', the parameter 'range_max' can not be less than the sum of "
|
|
|
|
|
<< "'num_sampled' and the num of unrepeat elements of input 'true_classes', "
|
|
|
|
|
<< " set remove_accidental_hits = false.";
|
|
|
|
|
remove_accidental_hits_ = false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
size_t sampled_candidate_size = static_cast<size_t>(num_sampled_ * sizeof(T));
|
|
|
|
|
int64_t counter = Sampling<T>(sampled_candidates, sampled_candidate_size);
|
|
|
|
|
size_t sampled_candidate_size = LongToSize(num_sampled_) * sizeof(T);
|
|
|
|
|
int64_t counter = Sampling<T>(sampled_candidates, RNG_seed, sampled_candidate_size);
|
|
|
|
|
// calculate expected count.
|
|
|
|
|
ExpectedLanuch<S>(counter, true_expected_count, sampled_expected_count);
|
|
|
|
|
|
|
|
|
|
input = input + input_size_;
|
|
|
|
|