add GPU UniformSampler

TFbunny 2020-10-21 12:20:48 -04:00
parent 9c79b9d712
commit e18a78feb9
7 changed files with 406 additions and 1 deletion


@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh"

// Fill output_array with the scalar prob_val using a grid-stride loop.
template <typename S>
__global__ void AssignToOutput(const int size, const S prob_val, S *output_array) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    output_array[pos] = prob_val;
  }
}

// Both expected-count outputs hold the same value, so two fill kernels are launched on the given stream.
template <typename S>
void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count,
                       S *sampled_expected_count, cudaStream_t cuda_stream) {
  AssignToOutput<<<GET_BLOCKS(true_size), GET_THREADS, 0, cuda_stream>>>(true_size, prob_val, true_expected_count);
  AssignToOutput<<<GET_BLOCKS(num_sampled), GET_THREADS, 0, cuda_stream>>>(num_sampled, prob_val,
                                                                           sampled_expected_count);
}

template void CalUniformSampler<float>(const int true_size, const int num_sampled, const float prob_val,
                                       float *true_expected_count, float *sampled_expected_count,
                                       cudaStream_t cuda_stream);
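
For orientation, a minimal host-side sketch of what the two kernel launches above compute, assuming NumPy; the function name cal_uniform_sampler_reference is hypothetical and only mirrors CalUniformSampler, which fills both expected-count buffers with the same scalar.

import numpy as np

def cal_uniform_sampler_reference(true_size, num_sampled, prob_val):
    # Mirrors the two AssignToOutput launches: every element of both outputs is prob_val.
    true_expected_count = np.full(true_size, prob_val, dtype=np.float32)
    sampled_expected_count = np.full(num_sampled, prob_val, dtype=np.float32)
    return true_expected_count, sampled_expected_count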


@@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_

#include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h"

template <typename S>
void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count,
                       S *sampled_expected_count, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_


@@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h"

namespace mindspore {
namespace kernel {
// Register the kernel for an int32 true_classes input and (int32, float32, float32) outputs,
// i.e. template arguments T = int and S = float.
MS_REG_GPU_KERNEL_TWO(UniformSampler,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      UniformSamplerGpuKernel, int, float)
}  // namespace kernel
}  // namespace mindspore


@@ -0,0 +1,144 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_

#include <cmath>
#include <ctime>
#include <set>
#include <vector>
#include <random>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T, typename S>
class UniformSamplerGpuKernel : public GpuKernel {
 public:
  UniformSamplerGpuKernel() : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0) {}
  ~UniformSamplerGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    VARIABLE_NOT_USED(workspaces);
    T *sampled_candidates = GetDeviceAddress<T>(outputs, 0);
    S *true_expected_count = GetDeviceAddress<S>(outputs, 1);
    S *sampled_expected_count = GetDeviceAddress<S>(outputs, 2);
    // Sampling runs on the host; the sampled ids are copied to device memory and the
    // expected-count outputs are filled by the CUDA kernel.
    int counter = Sampling();
    float prob = Probability();
    size_t sampled_candidates_size = num_sampled_ * sizeof(T);
    S value = ApproximateExpectedCount(prob, num_sampled_, counter);
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size,
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemcpyAsync sampled_candidates failed");
    CalUniformSampler(static_cast<int>(input_size_), num_sampled_, value, true_expected_count, sampled_expected_count,
                      reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 1) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but UniformSampler needs 1 input.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 3) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but UniformSampler has 3 outputs.";
      return false;
    }
    // getting attrs
    num_true_ = GetAttr<int>(kernel_node, "num_true");
    num_sampled_ = GetAttr<int>(kernel_node, "num_sampled");
    unique_ = GetAttr<bool>(kernel_node, "unique");
    range_max_ = GetAttr<int>(kernel_node, "range_max");
    int seed = GetAttr<int>(kernel_node, "seed");
    if (seed == 0) {
      seed = time(NULL);
    }
    generator_.seed(seed);

    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    if (input_shape.size() != 2) {
      MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformSampler supports only 2-D inputs.";
      return false;
    }
    input_size_ = input_shape[0] * input_shape[1];
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_ * sizeof(T));
    output_size_list_.push_back(num_sampled_ * sizeof(T));
    output_size_list_.push_back(input_size_ * sizeof(S));
    output_size_list_.push_back(num_sampled_ * sizeof(S));
  }

  // Draws num_sampled_ candidates uniformly from [0, range_max_ - 1] and returns the number of draws made,
  // which exceeds num_sampled_ when unique_ is true and duplicates are rejected.
  int Sampling() {
    int counter = 0;
    int tmp;
    int picked;
    std::set<int> set_container;
    // pick between [0, range_max_-1]
    std::uniform_int_distribution<int> distribution(0, range_max_ - 1);
    sampled_candidates_.clear();
    if (unique_) {
      picked = 0;
      while (picked < num_sampled_) {
        tmp = distribution(generator_);
        counter++;
        if (set_container.find(tmp) == set_container.end()) {
          set_container.insert(tmp);
          sampled_candidates_.push_back(tmp);
          picked++;
        }
      }
    } else {
      for (int i = 0; i < num_sampled_; i++) {
        sampled_candidates_.push_back(distribution(generator_));
      }
      counter = num_sampled_;
    }
    return counter;
  }

  S Probability() { return static_cast<S>(1.0f / range_max_); }

  // Expected count of a class after `counter` uniform draws with per-draw probability p:
  // p * sampled_size when sampling with replacement, otherwise 1 - (1 - p)^counter.
  S ApproximateExpectedCount(S p, int sampled_size, int counter) {
    if (sampled_size == counter) {
      return p * sampled_size;
    }
    return -std::expm1(counter * std::log1p(-p));
  }

 private:
  int num_true_;
  int num_sampled_;
  bool unique_;
  int range_max_;
  size_t input_size_;
  std::default_random_engine generator_;
  std::vector<int> sampled_candidates_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_
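
To make the host-side logic above easier to follow, here is a hedged Python sketch of what Sampling() and ApproximateExpectedCount() do; the names sample_uniform and approximate_expected_count are hypothetical stand-ins. With unique=True the kernel keeps drawing until num_sampled distinct classes are found, and the total number of draws (counter) feeds the expected-count formula 1 - (1 - p)^counter.

import math
import random

def sample_uniform(num_sampled, range_max, unique, seed=0):
    # Rejection-sample distinct classes when unique=True; otherwise draw with replacement.
    rng = random.Random(seed)
    picked, seen, draws = [], set(), 0
    while len(picked) < num_sampled:
        candidate = rng.randint(0, range_max - 1)
        draws += 1
        if not unique:
            picked.append(candidate)
        elif candidate not in seen:
            seen.add(candidate)
            picked.append(candidate)
    return picked, draws

def approximate_expected_count(p, num_sampled, draws):
    if draws == num_sampled:                    # sampling with replacement
        return p * num_sampled
    return -math.expm1(draws * math.log1p(-p))  # 1 - (1 - p)^draws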


@@ -79,7 +79,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
                     FusedSparseFtrl, FusedSparseProximalAdagrad,
                     ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
                     ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
-                    ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
+                    ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformSampler)
 from . import _quant_ops
 from ._quant_ops import *
 from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
@@ -373,6 +373,7 @@ __all__ = [
     "ApproximateEqual",
     "InplaceUpdate",
     "InTopK",
+    "UniformSampler",
     "LRN",
     "Mod",
     "PopulationCount",


@@ -5730,3 +5730,56 @@ class LRN(PrimitiveWithInfer):
    def infer_shape(self, x_shape):
        validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name)
        return x_shape


class UniformSampler(PrimitiveWithInfer):
    r"""
    Uniform candidate sampler.

    This function samples a set of classes (sampled_candidates) from [0, range_max-1] based on a uniform
    distribution. If unique=True, candidates are drawn without replacement; if unique=False, they are drawn
    with replacement.

    Args:
        num_true (int): The number of target classes in each training example.
        num_sampled (int): The number of classes to randomly sample. The **sampled_candidates** will have a shape
            of (num_sampled,). If unique=True, num_sampled must be less than or equal to range_max.
        unique (bool): Whether all sampled classes in a batch are unique.
        range_max (int): The number of possible classes.
        seed (int): Random seed, must be non-negative. Default: 0.

    Inputs:
        true_classes (int): A Tensor. The target classes with a tensor shape of (batch_size, num_true).

    Outputs:
        A tuple of 3 tensors.

        sampled_candidates (int): The sampled candidates, independent of the true classes. Shape: (num_sampled,).
        true_expected_count (float): The expected counts under the sampling distribution of each of true_classes.
            Shape: (batch_size, num_true).
        sampled_expected_count (float): The expected counts under the sampling distribution of each of
            sampled_candidates. Shape: (num_sampled,).

    Examples:
        >>> sampler = P.UniformSampler(1, 3, False, 4)
        >>> SampledCandidates, TrueExpectedCount, SampledExpectedCount = sampler(
        ...     Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32)))
        [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
    """

    @prim_attr_register
    def __init__(self, num_true, num_sampled, unique, range_max, seed=0):
        """Initialize UniformSampler"""
        validator.check_value_type("num_true", num_true, [int], self.name)
        validator.check_value_type("num_sampled", num_sampled, [int], self.name)
        validator.check_value_type("unique", unique, [bool], self.name)
        validator.check_value_type("range_max", range_max, [int], self.name)
        validator.check_value_type("seed", seed, [int], self.name)
        validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name)
        if unique:
            validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name)
        validator.check("value of seed", seed, '', 0, Rel.GE, self.name)
        self.num_sampled = num_sampled

    def infer_dtype(self, true_classes_type):
        return (true_classes_type, mstype.float32, mstype.float32)

    def infer_shape(self, true_classes_shape):
        return ([self.num_sampled], true_classes_shape, [self.num_sampled])
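
As a quick sanity check of the docstring example above: range_max=4 gives p = 1/4, and with unique=False the number of draws equals num_sampled, so every expected count is simply p * num_sampled.

p = 1.0 / 4            # range_max = 4
num_sampled = 3        # unique=False, so draws == num_sampled
expected_count = p * num_sampled
print(expected_count)  # 0.75, matching the 0.75 entries shown in the example output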


@@ -0,0 +1,116 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.context as context


class UniformSamplerNet(nn.Cell):
    def __init__(self, num_true, num_sampled, unique, range_max):
        super(UniformSamplerNet, self).__init__()
        self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max)

    def construct(self, x):
        return self.sampler(x)


def uniform_sampler(x, num_true, num_sampled, unique, range_max):
    uniform_sampler_net = UniformSamplerNet(num_true, num_sampled, unique, range_max)
    out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32)))
    return out1.shape, out2.shape, out3.shape


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_uniform_sampler_unique_1_true():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, True, 4)
    expected_1 = (3,)
    expected_2 = (5, 1)
    expected_3 = (3,)
    np.testing.assert_array_equal(ms1, expected_1)
    np.testing.assert_array_equal(ms2, expected_2)
    np.testing.assert_array_equal(ms3, expected_3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_uniform_sampler_not_unique_1_true():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, False, 4)
    expected_1 = (3,)
    expected_2 = (5, 1)
    expected_3 = (3,)
    np.testing.assert_array_equal(ms1, expected_1)
    np.testing.assert_array_equal(ms2, expected_2)
    np.testing.assert_array_equal(ms3, expected_3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_uniform_sampler_unique_2_true():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, True, 4)
    expected_1 = (3,)
    expected_2 = (5, 2)
    expected_3 = (3,)
    np.testing.assert_array_equal(ms1, expected_1)
    np.testing.assert_array_equal(ms2, expected_2)
    np.testing.assert_array_equal(ms3, expected_3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_uniform_sampler_not_unique_2_true():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, False, 4)
    expected_1 = (3,)
    expected_2 = (5, 2)
    expected_3 = (3,)
    np.testing.assert_array_equal(ms1, expected_1)
    np.testing.assert_array_equal(ms2, expected_2)
    np.testing.assert_array_equal(ms3, expected_3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_uniform_sampler_large():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    ms1, ms2, ms3 = uniform_sampler(np.array([[12221, 41414], [3312, 5125152], [3312454, 51252],
                                              [65125, 225125], [35125, 5125122]]), 2, 5, False, 100)
    expected_1 = (5,)
    expected_2 = (5, 2)
    expected_3 = (5,)
    np.testing.assert_array_equal(ms1, expected_1)
    np.testing.assert_array_equal(ms2, expected_2)
    np.testing.assert_array_equal(ms3, expected_3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_uniform_sampler_large_random():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    ms1, ms2, ms3 = uniform_sampler(np.arange(2142).reshape(34, 63), 63, 10, False, 12)
    expected_1 = (10,)
    expected_2 = (34, 63)
    expected_3 = (10,)
    np.testing.assert_array_equal(ms1, expected_1)
    np.testing.assert_array_equal(ms2, expected_2)
    np.testing.assert_array_equal(ms3, expected_3)