From e18a78feb9833a30d1c4d4ba3c94c357744f805a Mon Sep 17 00:00:00 2001 From: TFbunny Date: Wed, 21 Oct 2020 12:20:48 -0400 Subject: [PATCH] add GPU UniformSampler --- .../gpu/cuda_impl/uniform_sampler_impl.cu | 36 +++++ .../gpu/cuda_impl/uniform_sampler_impl.cuh | 26 ++++ .../gpu/nn/uniform_sampler_gpu_kernel.cc | 29 ++++ .../gpu/nn/uniform_sampler_gpu_kernel.h | 144 ++++++++++++++++++ mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/nn_ops.py | 53 +++++++ tests/st/ops/gpu/test_uniform_sampler_op.py | 116 ++++++++++++++ 7 files changed, 406 insertions(+), 1 deletion(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_uniform_sampler_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cu new file mode 100644 index 00000000000..9989b902745 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cu @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh" + +template +__global__ void AssignToOutput(const int size, const S prob_val, S *output_array) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output_array[pos] = prob_val; + } +} + +template +void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, + S *sampled_expected_count, cudaStream_t cuda_stream) { + AssignToOutput<<>>(true_size, prob_val, true_expected_count); + AssignToOutput<<>>(num_sampled, prob_val, + sampled_expected_count); +} + +template void CalUniformSampler(const int true_size, const int num_sampled, const float prob_val, + float *true_expected_count, float *sampled_expected_count, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh new file mode 100644 index 00000000000..367c159333a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ +#include +#include "runtime/device/gpu/cuda_common.h" + +template +void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, + S *sampled_expected_count, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.cc new file mode 100644 index 00000000000..56dd8723948 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(UniformSampler, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + UniformSamplerGpuKernel, int, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h new file mode 100644 index 00000000000..a1fc9ea4aaa --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class UniformSamplerGpuKernel : public GpuKernel { + public: + UniformSamplerGpuKernel() : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0) {} + ~UniformSamplerGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspaces, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspaces); + T *sampled_candidates = GetDeviceAddress(outputs, 0); + S *true_expected_count = GetDeviceAddress(outputs, 1); + S *sampled_expected_count = GetDeviceAddress(outputs, 2); + int counter = Sampling(); + float prob = Probability(); + size_t sampled_candidates_size = num_sampled_ * sizeof(T); + S value = ApproximateExpectedCount(prob, num_sampled_, counter); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync sampled_candidates failed"); + CalUniformSampler(static_cast(input_size_), num_sampled_, value, true_expected_count, sampled_expected_count, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but UniformSampler needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 3) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but UniformSampler has 3 outputs."; + return false; + } + // getting attrs + num_true_ = GetAttr(kernel_node, "num_true"); + num_sampled_ = GetAttr(kernel_node, "num_sampled"); + unique_ = GetAttr(kernel_node, "unique"); + range_max_ = GetAttr(kernel_node, "range_max"); + int seed = GetAttr(kernel_node, "seed"); + if (seed == 0) seed = time(NULL); + generator_.seed(seed); + auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (input_shape.size() != 2) { + MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformSampler supports only 2-D inputs."; + return false; + } + input_size_ = input_shape[0] * input_shape[1]; + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + output_size_list_.push_back(num_sampled_ * sizeof(T)); + output_size_list_.push_back(input_size_ * sizeof(S)); + output_size_list_.push_back(num_sampled_ * sizeof(S)); + } + + int Sampling() { + int counter = 0; + int tmp; + int picked; + std::set set_container; + // pick between [0, range_max_-1] + std::uniform_int_distribution distribution(0, range_max_ - 1); + sampled_candidates_.clear(); + if (unique_) { + picked = 0; + while (picked < num_sampled_) { + tmp = distribution(generator_); + counter++; + if (set_container.find(tmp) == set_container.end()) { + set_container.insert(tmp); + sampled_candidates_.push_back(tmp); + picked++; + } + } + } else { + for (int i = 0; i < num_sampled_; i++) { + sampled_candidates_.push_back(distribution(generator_)); + } + counter = num_sampled_; + } + return counter; + } + + S Probability() { return static_cast(1.0f / range_max_); } + + S ApproximateExpectedCount(S p, int sampled_size, int counter) { + if (sampled_size == counter) return p * sampled_size; + return -std::expm1(counter * std::log1p(-p)); + } + + private: + int num_true_; + int num_sampled_; + bool unique_; + int range_max_; + size_t input_size_; + std::default_random_engine generator_; + std::vector sampled_candidates_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index a7e283dc015..ed075f03757 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -79,7 +79,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl FusedSparseFtrl, FusedSparseProximalAdagrad, ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, - ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) + ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformSampler) from . import _quant_ops from ._quant_ops import * from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, @@ -373,6 +373,7 @@ __all__ = [ "ApproximateEqual", "InplaceUpdate", "InTopK", + "UniformSampler", "LRN", "Mod", "PopulationCount", diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 9426633642a..ce69bce8aaa 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -5730,3 +5730,56 @@ class LRN(PrimitiveWithInfer): def infer_shape(self, x_shape): validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name) return x_shape + + +class UniformSampler(PrimitiveWithInfer): + r""" + Uniform candidate sampler. + + This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution. + If unique=True, candidates are drawn without replacement, else unique=False with replacement. + + Args: + num_true (int): The number of target classes in each training example. + num_sampled (int): The number of classes to randomly sample. The **sampled_candidates** will have a shape + of num_sampled. If unique=True, num_sampled must be less than or equal to range_max. + unique (bool): Whether all sampled classes in a batch are unique. + range_max (int): The number of possible classes. + seed (int): Random seed, must be non-negative. Default: 0. + + Inputs: + true_classes (int): A tensor. The target classes with a tensor shape of (batch_size, num_true). + + Outputs: + A tuple of 3 tensors. + sampled_candidates: (int): The sampled_candidates is independent of the true classes. Shape: (num_sampled, ). + true_expected_count: (float): The expected counts under the sampling distribution of each of true_classes. + Shape: (batch_size, num_true). + sampled_expected_count: (float): The expected counts under the sampling distribution of each of + sampled_candidates. Shape: (num_sampled, ). + + Examples: + >>> sampler = P.UniformSampler(1, 3, False, 4) + >>> SampledCandidates, TrueExpectedCount, SampledExpectedCount = sampler(Tensor(np.array([[1],[3],[4],[6], + [3]], dtype=np.int32))) + [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75] + """ + @prim_attr_register + def __init__(self, num_true, num_sampled, unique, range_max, seed=0): + """Initialize UniformSampler""" + validator.check_value_type("num_true", num_true, [int], self.name) + validator.check_value_type("num_sampled", num_sampled, [int], self.name) + validator.check_value_type("unique", unique, [bool], self.name) + validator.check_value_type("range_max", range_max, [int], self.name) + validator.check_value_type("seed", seed, [int], self.name) + validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name) + if unique: + validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name) + validator.check("value of seed", seed, '', 0, Rel.GE, self.name) + self.num_sampled = num_sampled + + def infer_dtype(self, true_classes_type): + return (true_classes_type, mstype.float32, mstype.float32) + + def infer_shape(self, true_classes_shape): + return ([self.num_sampled], true_classes_shape, [self.num_sampled]) diff --git a/tests/st/ops/gpu/test_uniform_sampler_op.py b/tests/st/ops/gpu/test_uniform_sampler_op.py new file mode 100644 index 00000000000..f0625d8742d --- /dev/null +++ b/tests/st/ops/gpu/test_uniform_sampler_op.py @@ -0,0 +1,116 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import mindspore.context as context + +class UniformSamplerNet(nn.Cell): + def __init__(self, num_true, num_sampled, unique, range_max): + super(UniformSamplerNet, self).__init__() + self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max) + + def construct(self, x): + return self.sampler(x) + + +def uniform_sampler(x, num_true, num_sampled, unique, range_max): + uniform_sampler_net = UniformSamplerNet(num_true, num_sampled, unique, range_max) + out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32))) + return out1.shape, out2.shape, out3.shape + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_uniform_sampler_unique_1_true(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, True, 4) + expected_1 = (3,) + expected_2 = (5, 1) + expected_3 = (3,) + np.testing.assert_array_equal(ms1, expected_1) + np.testing.assert_array_equal(ms2, expected_2) + np.testing.assert_array_equal(ms3, expected_3) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_uniform_sampler_not_unique_1_true(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, False, 4) + expected_1 = (3,) + expected_2 = (5, 1) + expected_3 = (3,) + np.testing.assert_array_equal(ms1, expected_1) + np.testing.assert_array_equal(ms2, expected_2) + np.testing.assert_array_equal(ms3, expected_3) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_uniform_sampler_unique_2_true(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, True, 4) + expected_1 = (3,) + expected_2 = (5, 2) + expected_3 = (3,) + np.testing.assert_array_equal(ms1, expected_1) + np.testing.assert_array_equal(ms2, expected_2) + np.testing.assert_array_equal(ms3, expected_3) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_uniform_sampler_not_unique_2_true(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, False, 4) + expected_1 = (3,) + expected_2 = (5, 2) + expected_3 = (3,) + np.testing.assert_array_equal(ms1, expected_1) + np.testing.assert_array_equal(ms2, expected_2) + np.testing.assert_array_equal(ms3, expected_3) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_uniform_sampler_large(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + ms1, ms2, ms3 = uniform_sampler(np.array([[12221, 41414], [3312, 5125152], [3312454, 51252], + [65125, 225125], [35125, 5125122]]), 2, 5, False, 100) + expected_1 = (5,) + expected_2 = (5, 2) + expected_3 = (5,) + np.testing.assert_array_equal(ms1, expected_1) + np.testing.assert_array_equal(ms2, expected_2) + np.testing.assert_array_equal(ms3, expected_3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_uniform_sampler_large_random(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + ms1, ms2, ms3 = uniform_sampler(np.arange(2142).reshape(34, 63), 63, 10, False, 12) + expected_1 = (10,) + expected_2 = (34, 63) + expected_3 = (10,) + np.testing.assert_array_equal(ms1, expected_1) + np.testing.assert_array_equal(ms2, expected_2) + np.testing.assert_array_equal(ms3, expected_3)