add GPU UniformSampler
This commit is contained in:
parent
9c79b9d712
commit
e18a78feb9
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh"
|
||||
|
||||
template <typename S>
|
||||
__global__ void AssignToOutput(const int size, const S prob_val, S *output_array) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
output_array[pos] = prob_val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count,
|
||||
S *sampled_expected_count, cudaStream_t cuda_stream) {
|
||||
AssignToOutput<<<GET_BLOCKS(true_size), GET_THREADS, 0, cuda_stream>>>(true_size, prob_val, true_expected_count);
|
||||
AssignToOutput<<<GET_BLOCKS(num_sampled), GET_THREADS, 0, cuda_stream>>>(num_sampled, prob_val,
|
||||
sampled_expected_count);
|
||||
}
|
||||
|
||||
template void CalUniformSampler<float>(const int true_size, const int num_sampled, const float prob_val,
|
||||
float *true_expected_count, float *sampled_expected_count,
|
||||
cudaStream_t cuda_stream);
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_
|
||||
#include <cuda_runtime.h>
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename S>
|
||||
void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count,
|
||||
S *sampled_expected_count, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(UniformSampler,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
UniformSamplerGpuKernel, int, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,144 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_
|
||||
|
||||
#include <cmath>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class UniformSamplerGpuKernel : public GpuKernel {
|
||||
public:
|
||||
UniformSamplerGpuKernel() : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0) {}
|
||||
~UniformSamplerGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspaces);
|
||||
T *sampled_candidates = GetDeviceAddress<T>(outputs, 0);
|
||||
S *true_expected_count = GetDeviceAddress<S>(outputs, 1);
|
||||
S *sampled_expected_count = GetDeviceAddress<S>(outputs, 2);
|
||||
int counter = Sampling();
|
||||
float prob = Probability();
|
||||
size_t sampled_candidates_size = num_sampled_ * sizeof(T);
|
||||
S value = ApproximateExpectedCount(prob, num_sampled_, counter);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync sampled_candidates failed");
|
||||
CalUniformSampler(static_cast<int>(input_size_), num_sampled_, value, true_expected_count, sampled_expected_count,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but UniformSampler needs 1 input.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 3) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but UniformSampler has 3 outputs.";
|
||||
return false;
|
||||
}
|
||||
// getting attrs
|
||||
num_true_ = GetAttr<int>(kernel_node, "num_true");
|
||||
num_sampled_ = GetAttr<int>(kernel_node, "num_sampled");
|
||||
unique_ = GetAttr<bool>(kernel_node, "unique");
|
||||
range_max_ = GetAttr<int>(kernel_node, "range_max");
|
||||
int seed = GetAttr<int>(kernel_node, "seed");
|
||||
if (seed == 0) seed = time(NULL);
|
||||
generator_.seed(seed);
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
if (input_shape.size() != 2) {
|
||||
MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformSampler supports only 2-D inputs.";
|
||||
return false;
|
||||
}
|
||||
input_size_ = input_shape[0] * input_shape[1];
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_ * sizeof(T));
|
||||
output_size_list_.push_back(num_sampled_ * sizeof(T));
|
||||
output_size_list_.push_back(input_size_ * sizeof(S));
|
||||
output_size_list_.push_back(num_sampled_ * sizeof(S));
|
||||
}
|
||||
|
||||
int Sampling() {
|
||||
int counter = 0;
|
||||
int tmp;
|
||||
int picked;
|
||||
std::set<int> set_container;
|
||||
// pick between [0, range_max_-1]
|
||||
std::uniform_int_distribution<int> distribution(0, range_max_ - 1);
|
||||
sampled_candidates_.clear();
|
||||
if (unique_) {
|
||||
picked = 0;
|
||||
while (picked < num_sampled_) {
|
||||
tmp = distribution(generator_);
|
||||
counter++;
|
||||
if (set_container.find(tmp) == set_container.end()) {
|
||||
set_container.insert(tmp);
|
||||
sampled_candidates_.push_back(tmp);
|
||||
picked++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < num_sampled_; i++) {
|
||||
sampled_candidates_.push_back(distribution(generator_));
|
||||
}
|
||||
counter = num_sampled_;
|
||||
}
|
||||
return counter;
|
||||
}
|
||||
|
||||
S Probability() { return static_cast<S>(1.0f / range_max_); }
|
||||
|
||||
S ApproximateExpectedCount(S p, int sampled_size, int counter) {
|
||||
if (sampled_size == counter) return p * sampled_size;
|
||||
return -std::expm1(counter * std::log1p(-p));
|
||||
}
|
||||
|
||||
private:
|
||||
int num_true_;
|
||||
int num_sampled_;
|
||||
bool unique_;
|
||||
int range_max_;
|
||||
size_t input_size_;
|
||||
std::default_random_engine generator_;
|
||||
std::vector<int> sampled_candidates_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_
|
|
@ -79,7 +79,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
|
|||
FusedSparseFtrl, FusedSparseProximalAdagrad,
|
||||
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
|
||||
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
|
||||
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
|
||||
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformSampler)
|
||||
from . import _quant_ops
|
||||
from ._quant_ops import *
|
||||
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
|
||||
|
@ -373,6 +373,7 @@ __all__ = [
|
|||
"ApproximateEqual",
|
||||
"InplaceUpdate",
|
||||
"InTopK",
|
||||
"UniformSampler",
|
||||
"LRN",
|
||||
"Mod",
|
||||
"PopulationCount",
|
||||
|
|
|
@ -5730,3 +5730,56 @@ class LRN(PrimitiveWithInfer):
|
|||
def infer_shape(self, x_shape):
|
||||
validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name)
|
||||
return x_shape
|
||||
|
||||
|
||||
class UniformSampler(PrimitiveWithInfer):
|
||||
r"""
|
||||
Uniform candidate sampler.
|
||||
|
||||
This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution.
|
||||
If unique=True, candidates are drawn without replacement, else unique=False with replacement.
|
||||
|
||||
Args:
|
||||
num_true (int): The number of target classes in each training example.
|
||||
num_sampled (int): The number of classes to randomly sample. The **sampled_candidates** will have a shape
|
||||
of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
|
||||
unique (bool): Whether all sampled classes in a batch are unique.
|
||||
range_max (int): The number of possible classes.
|
||||
seed (int): Random seed, must be non-negative. Default: 0.
|
||||
|
||||
Inputs:
|
||||
true_classes (int): A tensor. The target classes with a tensor shape of (batch_size, num_true).
|
||||
|
||||
Outputs:
|
||||
A tuple of 3 tensors.
|
||||
sampled_candidates: (int): The sampled_candidates is independent of the true classes. Shape: (num_sampled, ).
|
||||
true_expected_count: (float): The expected counts under the sampling distribution of each of true_classes.
|
||||
Shape: (batch_size, num_true).
|
||||
sampled_expected_count: (float): The expected counts under the sampling distribution of each of
|
||||
sampled_candidates. Shape: (num_sampled, ).
|
||||
|
||||
Examples:
|
||||
>>> sampler = P.UniformSampler(1, 3, False, 4)
|
||||
>>> SampledCandidates, TrueExpectedCount, SampledExpectedCount = sampler(Tensor(np.array([[1],[3],[4],[6],
|
||||
[3]], dtype=np.int32)))
|
||||
[1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, num_true, num_sampled, unique, range_max, seed=0):
|
||||
"""Initialize UniformSampler"""
|
||||
validator.check_value_type("num_true", num_true, [int], self.name)
|
||||
validator.check_value_type("num_sampled", num_sampled, [int], self.name)
|
||||
validator.check_value_type("unique", unique, [bool], self.name)
|
||||
validator.check_value_type("range_max", range_max, [int], self.name)
|
||||
validator.check_value_type("seed", seed, [int], self.name)
|
||||
validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name)
|
||||
if unique:
|
||||
validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name)
|
||||
validator.check("value of seed", seed, '', 0, Rel.GE, self.name)
|
||||
self.num_sampled = num_sampled
|
||||
|
||||
def infer_dtype(self, true_classes_type):
|
||||
return (true_classes_type, mstype.float32, mstype.float32)
|
||||
|
||||
def infer_shape(self, true_classes_shape):
|
||||
return ([self.num_sampled], true_classes_shape, [self.num_sampled])
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
|
||||
class UniformSamplerNet(nn.Cell):
|
||||
def __init__(self, num_true, num_sampled, unique, range_max):
|
||||
super(UniformSamplerNet, self).__init__()
|
||||
self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max)
|
||||
|
||||
def construct(self, x):
|
||||
return self.sampler(x)
|
||||
|
||||
|
||||
def uniform_sampler(x, num_true, num_sampled, unique, range_max):
|
||||
uniform_sampler_net = UniformSamplerNet(num_true, num_sampled, unique, range_max)
|
||||
out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32)))
|
||||
return out1.shape, out2.shape, out3.shape
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_uniform_sampler_unique_1_true():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, True, 4)
|
||||
expected_1 = (3,)
|
||||
expected_2 = (5, 1)
|
||||
expected_3 = (3,)
|
||||
np.testing.assert_array_equal(ms1, expected_1)
|
||||
np.testing.assert_array_equal(ms2, expected_2)
|
||||
np.testing.assert_array_equal(ms3, expected_3)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_uniform_sampler_not_unique_1_true():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, False, 4)
|
||||
expected_1 = (3,)
|
||||
expected_2 = (5, 1)
|
||||
expected_3 = (3,)
|
||||
np.testing.assert_array_equal(ms1, expected_1)
|
||||
np.testing.assert_array_equal(ms2, expected_2)
|
||||
np.testing.assert_array_equal(ms3, expected_3)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_uniform_sampler_unique_2_true():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, True, 4)
|
||||
expected_1 = (3,)
|
||||
expected_2 = (5, 2)
|
||||
expected_3 = (3,)
|
||||
np.testing.assert_array_equal(ms1, expected_1)
|
||||
np.testing.assert_array_equal(ms2, expected_2)
|
||||
np.testing.assert_array_equal(ms3, expected_3)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_uniform_sampler_not_unique_2_true():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, False, 4)
|
||||
expected_1 = (3,)
|
||||
expected_2 = (5, 2)
|
||||
expected_3 = (3,)
|
||||
np.testing.assert_array_equal(ms1, expected_1)
|
||||
np.testing.assert_array_equal(ms2, expected_2)
|
||||
np.testing.assert_array_equal(ms3, expected_3)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_uniform_sampler_large():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
ms1, ms2, ms3 = uniform_sampler(np.array([[12221, 41414], [3312, 5125152], [3312454, 51252],
|
||||
[65125, 225125], [35125, 5125122]]), 2, 5, False, 100)
|
||||
expected_1 = (5,)
|
||||
expected_2 = (5, 2)
|
||||
expected_3 = (5,)
|
||||
np.testing.assert_array_equal(ms1, expected_1)
|
||||
np.testing.assert_array_equal(ms2, expected_2)
|
||||
np.testing.assert_array_equal(ms3, expected_3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_uniform_sampler_large_random():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
ms1, ms2, ms3 = uniform_sampler(np.arange(2142).reshape(34, 63), 63, 10, False, 12)
|
||||
expected_1 = (10,)
|
||||
expected_2 = (34, 63)
|
||||
expected_3 = (10,)
|
||||
np.testing.assert_array_equal(ms1, expected_1)
|
||||
np.testing.assert_array_equal(ms2, expected_2)
|
||||
np.testing.assert_array_equal(ms3, expected_3)
|
Loading…
Reference in New Issue