!51823 add int32 and seed for uniformcandidatesampler on Ascend
Merge pull request !51823 from yanzhenxiang2020/br_candidate_sampler
This commit is contained in:
commit
0fb84bb2c0
|
@ -126,3 +126,4 @@
|
|||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/indent"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "runtime/explicit"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "build/include_what_you_use"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.cc" "uninitMemberVar"
|
||||
|
|
|
@ -142,6 +142,22 @@
|
|||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/line_length"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "whitespace/semicolon"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/" "readability/nolint"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.cc" "build/include_subdir"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.cc" "whitespace/end_of_line"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.cc" "whitespace/parens"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.cc" "readability/casting"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.cc" "build/include_what_you_use"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.cc" "whitespace/ending_newline"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.h" "whitespace/indent"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.h" "readability/braces"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.h" "whitespace/blank_line"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.h" "build/include_what_you_use"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.h" "runtime/explicit"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/candidate_sampler_kernels.h" "whitespace/ending_newline"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/range_sampler.cc" "whitespace/ending_newline"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/range_sampler.h" "build/include_subdir"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/range_sampler.h" "runtime/references"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/common/range_sampler.h" "whitespace/ending_newline"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/expand_dims_kernels.h" "runtime/explicit"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/expand_dims_kernels.h" "build/include_order"
|
||||
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/max_pool_with_argmax_v2_cpu_kernel.h" "build/include_what_you_use"
|
||||
|
|
|
@ -34,11 +34,13 @@ if(EXISTS ${CMAKE_C_COMPILER} AND EXISTS ${CMAKE_CXX_COMPILER})
|
|||
set(AICPU_SRC
|
||||
${PROTO_SRCS}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/kernel_base.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/range_sampler.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/aicpu_sharder/aicpu_async_event.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/aicpu_sharder/aicpu_context.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/aicpu_sharder/aicpu_pulse.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/random_choice_with_mask_kernels.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gather_grad_kernels.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/candidate_sampler_kernels.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expand_dims_kernels.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/flatten_kernels.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reshape_kernels.cc
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "candidate_sampler_kernels.h"
|
||||
#include <algorithm>
|
||||
#include "range_sampler.h"
|
||||
|
||||
namespace aicpu {
|
||||
|
||||
uint32_t CandidateSamplerKernel::ParseKernelParam() {
|
||||
::google::protobuf::Map<::std::string, ::aicpuops::AttrValue> nodedef_attrs = node_def_.attrs();
|
||||
num_true_ = nodedef_attrs["num_true"].i();
|
||||
num_sampled_ = nodedef_attrs["num_sampled"].i();
|
||||
unique_ = nodedef_attrs["unique"].b();
|
||||
range_max_ = nodedef_attrs["range_max"].i();
|
||||
seed_ = nodedef_attrs["seed"].i();
|
||||
|
||||
// input0: true_classes
|
||||
::aicpuops::Tensor x_tensor = node_def_.inputs(0);
|
||||
x_dtype_ = static_cast<::aicpuops::DataType>(x_tensor.tensor_type());
|
||||
const ::aicpuops::TensorShape &x_shape = x_tensor.tensor_shape();
|
||||
for (auto i = 0; i < x_shape.dim_size(); i++) {
|
||||
x_shape_.emplace_back(x_shape.dim(i).size());
|
||||
}
|
||||
|
||||
if (x_shape_.size() != 2) {
|
||||
AICPU_LOGE("true_classes must be a matrix");
|
||||
return kAicpuKernelStateFailed;
|
||||
}
|
||||
if (x_shape_[1] != num_true_) {
|
||||
AICPU_LOGE(
|
||||
"true_classes must have "
|
||||
"num_true columns, expected: ",
|
||||
x_shape_[1], " was: ", num_true_);
|
||||
return kAicpuKernelStateFailed;
|
||||
}
|
||||
|
||||
batch_size_ = x_shape.dim(0).size();
|
||||
if (x_dtype_ != ::aicpuops::DataType::MS_INT64 && x_dtype_ != ::aicpuops::DataType::MS_INT32) {
|
||||
AICPU_LOGE("invalid type of x_dtype_: %d", x_dtype_);
|
||||
return kAicpuKernelStateFailed;
|
||||
}
|
||||
|
||||
// output_2: sampled_candidates
|
||||
::aicpuops::Tensor true_expected_count_tensor = node_def_.outputs(1);
|
||||
true_expected_count_dtype_ = static_cast<::aicpuops::DataType>(true_expected_count_tensor.tensor_type());
|
||||
if (true_expected_count_dtype_ != ::aicpuops::DataType::MS_FLOAT32) {
|
||||
AICPU_LOGE("invalid type of true_expected_count_dtype_: %d", true_expected_count_dtype_);
|
||||
return kAicpuKernelStateFailed;
|
||||
}
|
||||
return kAicpuKernelStateSucess;
|
||||
}
|
||||
|
||||
template <class RangeSamplerType, typename T>
|
||||
uint32_t CandidateSamplerKernel::DoComputeForEachType() {
|
||||
const int64_t batch_size = x_shape_[0];
|
||||
// input
|
||||
T *true_classes = reinterpret_cast<T *>(io_addrs_[0]);
|
||||
std::vector<T> true_candidate_raw(true_classes, true_classes + batch_size * num_true_);
|
||||
std::vector<int64_t> true_candidate(true_candidate_raw.size());
|
||||
std::transform(true_candidate_raw.begin(), true_candidate_raw.end(), true_candidate.begin(),
|
||||
[&](T x) { return static_cast<int64_t>(x); });
|
||||
std::vector<int64_t> sampled_candidate(num_sampled_);
|
||||
std::vector<T> sampled_candidate_raw(num_sampled_);
|
||||
std::vector<float> true_expected_count(batch_size * num_true_);
|
||||
std::vector<float> sampled_expected_count(num_sampled_);
|
||||
|
||||
set_sampler(new RangeSamplerType(range_max_));
|
||||
|
||||
if (unique_ && num_sampled_ > sampler_->range()) {
|
||||
AICPU_LOGE("Sampler's range is too small.");
|
||||
return kAicpuKernelStateFailed;
|
||||
}
|
||||
|
||||
sampler_->SampleBatchGetExpectedCount(unique_, seed_, sampled_candidate, sampled_expected_count, true_candidate,
|
||||
true_expected_count);
|
||||
|
||||
std::transform(sampled_candidate.begin(), sampled_candidate.end(), sampled_candidate_raw.begin(),
|
||||
[&](int64_t x) { return static_cast<T>(x); });
|
||||
int true_count_size = batch_size * num_true_ * sizeof(float);
|
||||
int ret1 = memcpy_s(reinterpret_cast<void *>(io_addrs_[1]), num_sampled_ * sizeof(T),
|
||||
(void *)&sampled_candidate_raw.front(), sampled_candidate_raw.size() * sizeof(T));
|
||||
int ret2 = memcpy_s(reinterpret_cast<void *>(io_addrs_[2]), true_count_size, (void *)&true_expected_count.front(),
|
||||
true_count_size);
|
||||
int ret3 = memcpy_s(reinterpret_cast<void *>(io_addrs_[3]), num_sampled_ * sizeof(float),
|
||||
(void *)&sampled_expected_count.front(), sampled_expected_count.size() * sizeof(float));
|
||||
if (ret1 < 0 || ret2 < 0 || ret3 < 0) {
|
||||
AICPU_LOGE("memcpy_s failed!");
|
||||
return kAicpuKernelStateFailed;
|
||||
}
|
||||
|
||||
return kAicpuKernelStateSucess;
|
||||
}
|
||||
|
||||
template <class RangeSamplerType>
|
||||
uint32_t CandidateSamplerKernel::CandidateSamplerCompute() {
|
||||
switch (x_dtype_) {
|
||||
case ::aicpuops::DataType::MS_INT32: {
|
||||
DoComputeForEachType<RangeSamplerType, int>();
|
||||
break;
|
||||
}
|
||||
case ::aicpuops::DataType::MS_INT64: {
|
||||
DoComputeForEachType<RangeSamplerType, int64_t>();
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
AICPU_LOGE("CandidateSampler op doesn't support input tensor types.");
|
||||
return kAicpuKernelStateFailed;
|
||||
}
|
||||
}
|
||||
return kAicpuKernelStateSucess;
|
||||
}
|
||||
|
||||
uint32_t LogUniformCandidateSamplerKernel::DoCompute() {
|
||||
LogUniformCandidateSamplerKernel::CandidateSamplerCompute<LogUniformSampler>();
|
||||
return kAicpuKernelStateSucess;
|
||||
}
|
||||
|
||||
uint32_t UniformCandidateSamplerKernel::DoCompute() {
|
||||
UniformCandidateSamplerKernel::CandidateSamplerCompute<UniformSampler>();
|
||||
return kAicpuKernelStateSucess;
|
||||
}
|
||||
|
||||
} // namespace aicpu
|
||||
|
||||
extern "C" {
|
||||
__attribute__((visibility("default"))) uint32_t LogUniformCandidateSampler(void *param) {
|
||||
aicpu::LogUniformCandidateSamplerKernel logUniformCandidateSampler;
|
||||
return logUniformCandidateSampler.Compute(param);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
__attribute__((visibility("default"))) uint32_t UniformCandidateSampler(void *param) {
|
||||
aicpu::UniformCandidateSamplerKernel uniformCandidateSampler;
|
||||
return uniformCandidateSampler.Compute(param);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef AICPU_LOG_UNIFORM_CANDIDATE_SAMPLER_KERNELS_H
|
||||
#define AICPU_LOG_UNIFORM_CANDIDATE_SAMPLER_KERNELS_H
|
||||
|
||||
#include <utility>
|
||||
#include "common/kernel_base.h"
|
||||
#include "common/kernel_errcode.h"
|
||||
#include "common/range_sampler.h"
|
||||
#include "common/kernel_log.h"
|
||||
#include "proto/node_def.pb.h"
|
||||
#include "proto/attr.pb.h"
|
||||
|
||||
namespace aicpu {
|
||||
class CandidateSamplerKernel : public KernelBase {
|
||||
public:
|
||||
explicit CandidateSamplerKernel(const std::string &kernel_name) : KernelBase(kernel_name){};
|
||||
~CandidateSamplerKernel() = default;
|
||||
|
||||
protected:
|
||||
uint32_t ParseKernelParam() override;
|
||||
template <class RangeSamplerType, typename T>
|
||||
uint32_t DoComputeForEachType();
|
||||
template <class RangeSamplerType>
|
||||
uint32_t CandidateSamplerCompute();
|
||||
|
||||
private:
|
||||
int num_true_;
|
||||
int num_sampled_;
|
||||
bool unique_;
|
||||
int64_t range_max_;
|
||||
int64_t seed_;
|
||||
std::unique_ptr<RangeSampler> sampler_;
|
||||
|
||||
int batch_size_ = 0;
|
||||
std::vector<int64_t> x_shape_;
|
||||
|
||||
::aicpuops::DataType x_dtype_ = ::aicpuops::DataType::MS_UNKNOWN;
|
||||
::aicpuops::DataType true_expected_count_dtype_ = ::aicpuops::DataType::MS_UNKNOWN;
|
||||
|
||||
void set_sampler(RangeSampler *sampler) { sampler_.reset(sampler); }
|
||||
|
||||
}; // CandidateSamplerKernel
|
||||
|
||||
class LogUniformCandidateSamplerKernel : public CandidateSamplerKernel {
|
||||
public:
|
||||
explicit LogUniformCandidateSamplerKernel() : CandidateSamplerKernel("LogUniformCandidateSampler"){};
|
||||
~LogUniformCandidateSamplerKernel() = default;
|
||||
|
||||
protected:
|
||||
uint32_t DoCompute() override;
|
||||
};
|
||||
|
||||
class UniformCandidateSamplerKernel : public CandidateSamplerKernel {
|
||||
public:
|
||||
explicit UniformCandidateSamplerKernel() : CandidateSamplerKernel("UniformCandidateSampler"){};
|
||||
~UniformCandidateSamplerKernel() = default;
|
||||
|
||||
protected:
|
||||
uint32_t DoCompute() override;
|
||||
};
|
||||
|
||||
} // namespace aicpu
|
||||
#endif // AICPU_LOG_UNIFORM_CANDIDATE_SAMPLER_KERNELS_H
|
|
@ -0,0 +1,143 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/range_sampler.h"
|
||||
#include <cmath>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include "common/distinct_uniform_int_distribution.h"
|
||||
|
||||
namespace aicpu {
|
||||
|
||||
RangeSampler::~RangeSampler() {}
|
||||
|
||||
void RangeSampler::SampleBatch(bool unique, std::vector<int64_t> &batch) const {}
|
||||
|
||||
void RangeSampler::SampleBatchGetExpectedCount(bool unique, int64_t seed, std::vector<int64_t> &batch,
|
||||
std::vector<float> &batch_expected_count, std::vector<int64_t> extras,
|
||||
std::vector<float> &extras_expected_count) const {
|
||||
SampleBatchGetExpectedCountAvoid(unique, seed, batch, batch_expected_count, extras, extras_expected_count,
|
||||
std::vector<int64_t>());
|
||||
}
|
||||
|
||||
namespace {
|
||||
static float ExpectedCountHelper(float p, int batch_size, int num_tries) {
|
||||
if (num_tries == batch_size) {
|
||||
return p * batch_size;
|
||||
}
|
||||
return -std::expm1(num_tries * std::log1p(-p));
|
||||
}
|
||||
|
||||
template <class Collection>
|
||||
bool InsertIfNotPresent(Collection *const collection, const typename Collection::value_type &vt) {
|
||||
return collection->insert(vt).second;
|
||||
}
|
||||
|
||||
static const int32_t kint32max = static_cast<int32_t>(0x7FFFFFFF);
|
||||
|
||||
} // namespace
|
||||
|
||||
void RangeSampler::SampleBatchGetExpectedCountAvoid(bool unique, int64_t seed, std::vector<int64_t> &batch,
|
||||
std::vector<float> &batch_expected_count,
|
||||
std::vector<int64_t> extras,
|
||||
std::vector<float> &extras_expected_count,
|
||||
std::vector<int64_t> avoided_values) const {
|
||||
const int batch_size = batch.size();
|
||||
int num_tries;
|
||||
if (range_ <= 0) {
|
||||
AICPU_LOGE("range_ must be greater than 0!");
|
||||
return;
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
int64_t seed_rng = (seed != 0) ? seed : rd();
|
||||
rnd_.seed(static_cast<uint64_t>(seed_rng));
|
||||
if (unique) {
|
||||
if (batch_size + avoided_values.size() > static_cast<size_t>(range_)) {
|
||||
AICPU_LOGE("the value should be less than range_: %d, but got %d", range_, batch_size + avoided_values.size());
|
||||
return;
|
||||
}
|
||||
std::unordered_set<int64_t> used(batch_size);
|
||||
used.insert(avoided_values.begin(), avoided_values.end());
|
||||
int num_picked = 0;
|
||||
num_tries = 0;
|
||||
while (num_picked < batch_size) {
|
||||
num_tries++;
|
||||
if (num_tries >= kint32max) {
|
||||
AICPU_LOGE("num_tries: %d should be less than kint32max: %d!", num_tries, kint32max);
|
||||
return;
|
||||
}
|
||||
int64_t value = Sample();
|
||||
if (InsertIfNotPresent(&used, value)) {
|
||||
batch[num_picked++] = value;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (avoided_values.size() != size_t{0}) {
|
||||
AICPU_LOGE("avoided_values only supported with unique=true");
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
batch[i] = Sample();
|
||||
}
|
||||
num_tries = batch_size;
|
||||
}
|
||||
|
||||
if (!batch_expected_count.empty()) {
|
||||
if (batch_size != static_cast<int>(batch_expected_count.size())) {
|
||||
AICPU_LOGE("the size of extras_expected_count: %zu should be equal to batch_size: %d!",
|
||||
batch_expected_count.size(), batch_size);
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
batch_expected_count[i] = ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries);
|
||||
}
|
||||
}
|
||||
if (extras.size() != extras_expected_count.size()) {
|
||||
AICPU_LOGE("the size of extras and extras_expected_count should be equal!");
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < extras.size(); i++) {
|
||||
extras_expected_count[i] = ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries);
|
||||
}
|
||||
}
|
||||
|
||||
UniformSampler::UniformSampler(int64_t range) : RangeSampler(range), inv_range_(1.0 / range) {}
|
||||
|
||||
int64_t UniformSampler::Sample() const {
|
||||
aicpu::distinct_uniform_int_distribution<> dis(0, range_ - 1);
|
||||
return dis.exec(&rnd_);
|
||||
}
|
||||
|
||||
float UniformSampler::Probability(int64_t value) const { return inv_range_; }
|
||||
|
||||
LogUniformSampler::LogUniformSampler(int64_t range) : RangeSampler(range), log_range_(log1p(range)) {}
|
||||
|
||||
int64_t LogUniformSampler::Sample() const {
|
||||
std::uniform_real_distribution<float> uni_real(0.0, 1.0);
|
||||
|
||||
const int64_t value = static_cast<int64_t>(exp(uni_real(rnd_) * log_range_)) - 1;
|
||||
if (value < 0) {
|
||||
AICPU_LOGE("value: %d should be >= 0", value);
|
||||
return 0;
|
||||
}
|
||||
|
||||
return value % range_;
|
||||
}
|
||||
|
||||
float LogUniformSampler::Probability(int64_t value) const { return (log((value + 2.0) / (value + 1.0))) / log_range_; }
|
||||
} // namespace aicpu
|
|
@ -0,0 +1,81 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef SRC_COMMON_RANGE_SAMPLER_H_
|
||||
#define SRC_COMMON_RANGE_SAMPLER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
#include <random>
|
||||
#include "kernel_log.h"
|
||||
|
||||
namespace aicpu {
|
||||
class RangeSampler {
|
||||
public:
|
||||
explicit RangeSampler(int64_t range) : range_(range) {}
|
||||
virtual ~RangeSampler();
|
||||
|
||||
virtual int64_t Sample() const = 0;
|
||||
|
||||
virtual float Probability(int64_t value) const = 0;
|
||||
|
||||
void SampleBatch(bool unique, std::vector<int64_t> &batch) const;
|
||||
|
||||
void SampleBatchGetExpectedCount(bool unique, int64_t seed, std::vector<int64_t> &batch,
|
||||
std::vector<float> &batch_expected_count, std::vector<int64_t> extras,
|
||||
std::vector<float> &extras_expected_count) const;
|
||||
|
||||
virtual void SampleBatchGetExpectedCountAvoid(bool unique, int64_t seed, std::vector<int64_t> &batch,
|
||||
std::vector<float> &batch_expected_count, std::vector<int64_t> extras,
|
||||
std::vector<float> &extras_expected_count,
|
||||
std::vector<int64_t> avoided_values) const;
|
||||
|
||||
int64_t range() { return range_; }
|
||||
|
||||
protected:
|
||||
const int64_t range_;
|
||||
mutable std::mt19937 rnd_;
|
||||
};
|
||||
|
||||
class UniformSampler : public RangeSampler {
|
||||
public:
|
||||
explicit UniformSampler(int64_t range);
|
||||
|
||||
~UniformSampler() override {}
|
||||
|
||||
int64_t Sample() const override;
|
||||
|
||||
float Probability(int64_t value) const override;
|
||||
|
||||
private:
|
||||
const float inv_range_;
|
||||
};
|
||||
|
||||
class LogUniformSampler : public RangeSampler {
|
||||
public:
|
||||
explicit LogUniformSampler(int64_t range);
|
||||
|
||||
~LogUniformSampler() override {}
|
||||
|
||||
int64_t Sample() const override;
|
||||
|
||||
float Probability(int64_t value) const override;
|
||||
|
||||
private:
|
||||
const double log_range_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif // SRC_COMMON_RANGE_SAMPLER_H_
|
|
@ -178,6 +178,7 @@ constexpr auto kAddcdiv = "Addcdiv";
|
|||
constexpr auto kAddcmul = "Addcmul";
|
||||
constexpr auto kAdd = "Add";
|
||||
constexpr auto kTriu = "Triu";
|
||||
constexpr auto kUniformCandidateSampler = "UniformCandidateSampler";
|
||||
constexpr auto kExpand = "Expand";
|
||||
constexpr auto kExpandDims = "ExpandDims";
|
||||
constexpr auto kReshape = "Reshape";
|
||||
|
@ -320,6 +321,7 @@ const std::set<std::string> kCpuKernelBaseOps{kDropoutGenMaskOpName,
|
|||
kReshape,
|
||||
kFlatten,
|
||||
kSqueeze,
|
||||
kUniformCandidateSampler,
|
||||
kExpandDims};
|
||||
const std::set<std::string> kDynamicInputOps{kRaggedTensorToTensor,
|
||||
kSparseCross,
|
||||
|
|
|
@ -52,6 +52,7 @@ static const std::unordered_set<std::string> kAICpuOpNames = {kDropoutGenMaskOpN
|
|||
kReshapeOpName,
|
||||
kFlattenOpName,
|
||||
kSqueezeOpName,
|
||||
kUniformCandidateSamplerOpName,
|
||||
kExpandDimsOpName};
|
||||
static const std::unordered_set<std::string> kMigrateAicpuKernelOps = {
|
||||
mindspore::kAdaptiveAvgPool2DOpName,
|
||||
|
|
|
@ -27,6 +27,7 @@ uniform_candidate_sampler_op_info = AiCPURegOp("UniformCandidateSampler") \
|
|||
.attr("range_max", "int") \
|
||||
.attr("seed", "int") \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue