From ee4e2db77e70a6229a6665aafc5d18ebaa9e030a Mon Sep 17 00:00:00 2001 From: TFbunny Date: Mon, 2 Nov 2020 16:31:44 -0500 Subject: [PATCH] rename UniformSampler to UniformCandidateSampler --- ...l.cu => uniform_candidate_sampler_impl.cu} | 12 +-- ...cuh => uniform_candidate_sampler_impl.cuh} | 10 +-- ...> uniform_candidate_sampler_gpu_kernel.cc} | 6 +- ...=> uniform_candidate_sampler_gpu_kernel.h} | 24 +++--- mindspore/nn/loss/loss.py | 2 +- mindspore/ops/operations/__init__.py | 4 +- mindspore/ops/operations/nn_ops.py | 6 +- ...y => test_uniform_candidate_sampler_op.py} | 85 ++++++++++++------- 8 files changed, 85 insertions(+), 64 deletions(-) rename mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/{uniform_sampler_impl.cu => uniform_candidate_sampler_impl.cu} (67%) rename mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/{uniform_sampler_impl.cuh => uniform_candidate_sampler_impl.cuh} (74%) rename mindspore/ccsrc/backend/kernel_compiler/gpu/nn/{uniform_sampler_gpu_kernel.cc => uniform_candidate_sampler_gpu_kernel.cc} (83%) rename mindspore/ccsrc/backend/kernel_compiler/gpu/nn/{uniform_sampler_gpu_kernel.h => uniform_candidate_sampler_gpu_kernel.h} (87%) rename tests/st/ops/gpu/{test_uniform_sampler_op.py => test_uniform_candidate_sampler_op.py} (51%) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cu similarity index 67% rename from mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cu index 9989b902745..d57ca7907d5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh" template __global__ void AssignToOutput(const int size, const S prob_val, S *output_array) { @@ -24,13 +24,13 @@ __global__ void AssignToOutput(const int size, const S prob_val, S *output_array } template -void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, - S *sampled_expected_count, cudaStream_t cuda_stream) { +void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, + S *sampled_expected_count, cudaStream_t cuda_stream) { AssignToOutput<<>>(true_size, prob_val, true_expected_count); AssignToOutput<<>>(num_sampled, prob_val, sampled_expected_count); } -template void CalUniformSampler(const int true_size, const int num_sampled, const float prob_val, - float *true_expected_count, float *sampled_expected_count, - cudaStream_t cuda_stream); +template void CalUniformCandidateSampler(const int true_size, const int num_sampled, const float prob_val, + float *true_expected_count, float *sampled_expected_count, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh similarity index 74% rename from mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh index 367c159333a..314c3e1c861 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh @@ -14,13 +14,13 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_ #include #include "runtime/device/gpu/cuda_common.h" template -void CalUniformSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, - S *sampled_expected_count, cudaStream_t cuda_stream); +void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count, + S *sampled_expected_count, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_SAMPLER_IMPL_CUH_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.cc similarity index 83% rename from mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.cc rename to mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.cc index 56dd8723948..6de4aa2748d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.cc @@ -14,16 +14,16 @@ * limitations under the License. */ -#include "backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.h" namespace mindspore { namespace kernel { -MS_REG_GPU_KERNEL_TWO(UniformSampler, +MS_REG_GPU_KERNEL_TWO(UniformCandidateSampler, KernelAttr() .AddInputAttr(kNumberTypeInt32) .AddOutputAttr(kNumberTypeInt32) .AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32), - UniformSamplerGpuKernel, int, float) + UniformCandidateSamplerGpuKernel, int, float) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.h similarity index 87% rename from mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h rename to mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.h index fe45b0ebc42..ad6e29b3dbd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_sampler_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/uniform_candidate_sampler_gpu_kernel.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_ #include #include @@ -23,16 +23,16 @@ #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" -#include "backend/kernel_compiler/gpu/cuda_impl/uniform_sampler_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh" namespace mindspore { namespace kernel { template -class UniformSamplerGpuKernel : public GpuKernel { +class UniformCandidateSamplerGpuKernel : public GpuKernel { public: - UniformSamplerGpuKernel() + UniformCandidateSamplerGpuKernel() : num_true_(0), num_sampled_(0), unique_(false), range_max_(0), input_size_(0), remove_accidental_hits_(false) {} - ~UniformSamplerGpuKernel() override = default; + ~UniformCandidateSamplerGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } @@ -61,20 +61,20 @@ class UniformSamplerGpuKernel : public GpuKernel { CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size, cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), "cudaMemcpyAsync sampled_candidates failed"); - CalUniformSampler(static_cast(input_size_), num_sampled_, value, true_expected_count, sampled_expected_count, - reinterpret_cast(stream_ptr)); + CalUniformCandidateSampler(static_cast(input_size_), num_sampled_, value, true_expected_count, + sampled_expected_count, reinterpret_cast(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but UniformSampler needs 1 input."; + MS_LOG(ERROR) << "Input number is " << input_num << ", but UniformCandidateSampler needs 1 input."; return false; } size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); if (output_num != 3) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but UniformSampler has 3 outputs."; + MS_LOG(ERROR) << "Output number is " << output_num << ", but UniformCandidateSampler has 3 outputs."; return false; } // getting attrs @@ -88,7 +88,7 @@ class UniformSamplerGpuKernel : public GpuKernel { generator_.seed(seed); auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); if (input_shape.size() != 2) { - MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformSampler supports only 2-D inputs."; + MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but UniformCandidateSampler supports only 2-D inputs."; return false; } input_size_ = input_shape[0] * input_shape[1]; @@ -160,4 +160,4 @@ class UniformSamplerGpuKernel : public GpuKernel { }; } // namespace kernel } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_SAMPLER_GPU_KERNEL_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_UNIFORM_CANDIDATE_SAMPLER_GPU_KERNEL_H_ diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py index 1c834bf172c..a7ee4fd2da4 100644 --- a/mindspore/nn/loss/loss.py +++ b/mindspore/nn/loss/loss.py @@ -303,7 +303,7 @@ class SampledSoftmaxLoss(_Loss): self.sampled_values = sampled_values self.remove_accidental_hits = remove_accidental_hits self.seed = seed - self.sampler = P.UniformSampler( + self.sampler = P.UniformCandidateSampler( num_true, num_sampled, True, diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 7cc3ab6b002..a040eb26669 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -79,7 +79,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl FusedSparseFtrl, FusedSparseProximalAdagrad, ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, - ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformSampler) + ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, UniformCandidateSampler) from . import _quant_ops from ._quant_ops import * from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, @@ -375,7 +375,7 @@ __all__ = [ "ApproximateEqual", "InplaceUpdate", "InTopK", - "UniformSampler", + "UniformCandidateSampler", "LRN", "Mod", "PopulationCount", diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index a50f04efff8..9cfd2c2c5c8 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -5820,7 +5820,7 @@ class LRN(PrimitiveWithInfer): return x_shape -class UniformSampler(PrimitiveWithInfer): +class UniformCandidateSampler(PrimitiveWithInfer): r""" Uniform candidate sampler. @@ -5848,14 +5848,14 @@ class UniformSampler(PrimitiveWithInfer): sampled_candidates. Shape: (num_sampled, ). Examples: - >>> sampler = P.UniformSampler(1, 3, False, 4) + >>> sampler = P.UniformCandidateSampler(1, 3, False, 4) >>> SampledCandidates, TrueExpectedCount, SampledExpectedCount = sampler(Tensor(np.array([[1],[3],[4],[6], [3]], dtype=np.int32))) [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75] """ @prim_attr_register def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False): - """Initialize UniformSampler""" + """Initialize UniformCandidateSampler""" validator.check_value_type("num_true", num_true, [int], self.name) validator.check_value_type("num_sampled", num_sampled, [int], self.name) validator.check_value_type("unique", unique, [bool], self.name) diff --git a/tests/st/ops/gpu/test_uniform_sampler_op.py b/tests/st/ops/gpu/test_uniform_candidate_sampler_op.py similarity index 51% rename from tests/st/ops/gpu/test_uniform_sampler_op.py rename to tests/st/ops/gpu/test_uniform_candidate_sampler_op.py index a3650a27a8c..6e6628f3b43 100644 --- a/tests/st/ops/gpu/test_uniform_sampler_op.py +++ b/tests/st/ops/gpu/test_uniform_candidate_sampler_op.py @@ -21,45 +21,55 @@ from mindspore.ops import operations as P import mindspore.nn as nn import mindspore.context as context -class UniformSamplerNet(nn.Cell): +class UniformCandidateSamplerNet(nn.Cell): def __init__(self, num_true, num_sampled, unique, range_max): - super(UniformSamplerNet, self).__init__() - self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max) + super(UniformCandidateSamplerNet, self).__init__() + self.sampler = P.UniformCandidateSampler(num_true, num_sampled, + unique, range_max) def construct(self, x): return self.sampler(x) -def uniform_sampler(x, num_true, num_sampled, unique, range_max): - uniform_sampler_net = UniformSamplerNet(num_true, num_sampled, unique, range_max) - out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32))) +def uniform_candidate_sampler(x, num_true, num_sampled, unique, range_max): + uniform_candidate_sampler_net = UniformCandidateSamplerNet(num_true, + num_sampled, + unique, + range_max) + out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int32))) return out1.shape, out2.shape, out3.shape -class UniformSamplerHitNet(nn.Cell): +class UniformCandidateSamplerHitNet(nn.Cell): def __init__(self, num_true, num_sampled, unique, range_max, seed, remove_accidental_hits): - super(UniformSamplerHitNet, self).__init__() - self.sampler = P.UniformSampler(num_true, num_sampled, unique, range_max, seed=seed, - remove_accidental_hits=remove_accidental_hits) + super(UniformCandidateSamplerHitNet, self).__init__() + self.sampler = P.UniformCandidateSampler(num_true, num_sampled, unique, + range_max, seed=seed, + remove_accidental_hits=remove_accidental_hits) def construct(self, x): return self.sampler(x) -def uniform_sampler_hit(x, num_true, num_sampled, unique, range_max, seed, - remove_accidental_hits): - uniform_sampler_net = UniformSamplerHitNet(num_true, num_sampled, unique, range_max, - seed, remove_accidental_hits) - out1, out2, out3 = uniform_sampler_net(Tensor(x.astype(np.int32))) +def uniform_candidate_sampler_hit(x, num_true, num_sampled, unique, range_max, seed, + remove_accidental_hits): + uniform_candidate_sampler_net = UniformCandidateSamplerHitNet(num_true, + num_sampled, + unique, + range_max, + seed, + remove_accidental_hits) + out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int32))) return out1, out2, out3 @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_uniform_sampler_unique_1_true(): +def test_uniform_candidate_sampler_unique_1_true(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, True, 4) + ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1], [3], [4], [6], [3]]), + 1, 3, True, 4) expected_1 = (3,) expected_2 = (5, 1) expected_3 = (3,) @@ -70,9 +80,10 @@ def test_uniform_sampler_unique_1_true(): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_uniform_sampler_not_unique_1_true(): +def test_uniform_candidate_sampler_not_unique_1_true(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - ms1, ms2, ms3 = uniform_sampler(np.array([[1], [3], [4], [6], [3]]), 1, 3, False, 4) + ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1], [3], [4], [6], [3]]), + 1, 3, False, 4) expected_1 = (3,) expected_2 = (5, 1) expected_3 = (3,) @@ -83,9 +94,11 @@ def test_uniform_sampler_not_unique_1_true(): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_uniform_sampler_unique_2_true(): +def test_uniform_candidate_sampler_unique_2_true(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, True, 4) + ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1, 2], [3, 2], [4, 2], + [6, 2], [3, 2]]), + 2, 3, True, 4) expected_1 = (3,) expected_2 = (5, 2) expected_3 = (3,) @@ -96,9 +109,12 @@ def test_uniform_sampler_unique_2_true(): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_uniform_sampler_not_unique_2_true(): +def test_uniform_candidate_sampler_not_unique_2_true(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - ms1, ms2, ms3 = uniform_sampler(np.array([[1, 2], [3, 2], [4, 2], [6, 2], [3, 2]]), 2, 3, False, 4) + ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[1, 2], [3, 2], + [4, 2], [6, 2], + [3, 2]]), + 2, 3, False, 4) expected_1 = (3,) expected_2 = (5, 2) expected_3 = (3,) @@ -109,10 +125,14 @@ def test_uniform_sampler_not_unique_2_true(): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_uniform_sampler_large(): +def test_uniform_candidate_sampler_large(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - ms1, ms2, ms3 = uniform_sampler(np.array([[12221, 41414], [3312, 5125152], [3312454, 51252], - [65125, 225125], [35125, 5125122]]), 2, 5, False, 100) + ms1, ms2, ms3 = uniform_candidate_sampler(np.array([[12221, 41414], + [3312, 5125152], + [3312454, 51252], + [65125, 225125], + [35125, 5125122]]), + 2, 5, False, 100) expected_1 = (5,) expected_2 = (5, 2) expected_3 = (5,) @@ -124,9 +144,10 @@ def test_uniform_sampler_large(): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_uniform_sampler_large_random(): +def test_uniform_candidate_sampler_large_random(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - ms1, ms2, ms3 = uniform_sampler(np.arange(2142).reshape(34, 63), 63, 10, False, 12) + ms1, ms2, ms3 = uniform_candidate_sampler(np.arange(2142).reshape(34, 63), + 63, 10, False, 12) expected_1 = (10,) expected_2 = (34, 63) expected_3 = (10,) @@ -138,9 +159,9 @@ def test_uniform_sampler_large_random(): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_uniform_sampler_unique_1_true_hit(): +def test_uniform_candidate_sampler_unique_1_true_hit(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, False) + ms1, _, _ = uniform_candidate_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, False) expected_1 = np.array([0, 3, 1]) np.testing.assert_array_equal(ms1.asnumpy(), expected_1) @@ -148,8 +169,8 @@ def test_uniform_sampler_unique_1_true_hit(): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_uniform_sampler_unique_1_true_no_hit(): +def test_uniform_candidate_sampler_unique_1_true_no_hit(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - ms1, _, _ = uniform_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, True) + ms1, _, _ = uniform_candidate_sampler_hit(np.array([[1]]), 1, 3, True, 4, 1, True) expected_1 = np.array([0, 3, 2]) np.testing.assert_array_equal(ms1.asnumpy(), expected_1)