From 1215e4e8d37a539b7705e48bc3ff428b334da72e Mon Sep 17 00:00:00 2001 From: fan-jibin Date: Sun, 21 Aug 2022 19:00:30 +0800 Subject: [PATCH] add randomshuffle_gpu --- .../cuda_impl/cuda_ops/random_shuffle_impl.cu | 79 +++++++ .../cuda_ops/random_shuffle_impl.cuh | 26 +++ .../random/random_shuffle_gpu_kernel.cc | 198 ++++++++++++++++++ .../kernel/random/random_shuffle_gpu_kernel.h | 80 +++++++ .../python/mindspore/ops/_vmap/vmap_base.py | 3 +- .../mindspore/ops/function/random_func.py | 2 +- .../mindspore/ops/operations/random_ops.py | 2 +- tests/st/ops/gpu/test_random_shuffle.py | 139 ++++++++++++ 8 files changed, 526 insertions(+), 3 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_shuffle_impl.cu create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_shuffle_impl.cuh create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/random/random_shuffle_gpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/random/random_shuffle_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_random_shuffle.py diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_shuffle_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_shuffle_impl.cu new file mode 100644 index 00000000000..2f5529dcc63 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_shuffle_impl.cu @@ -0,0 +1,79 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "random_shuffle_impl.cuh"
+#include <algorithm>
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
+#include "include/cuda_fp16.h"
+
+template <typename T>
+using Complex = mindspore::utils::Complex<T>;
+
+#define SHUFFLE_DECLARE(type)                                                                                \
+  template CUDA_LIB_EXPORT void ScalarShuffle<type>(const int64_t size, const int *perm, const type *input,  \
+                                                    type *output, const uint32_t device_id,                  \
+                                                    cudaStream_t cuda_stream);                               \
+  template CUDA_LIB_EXPORT void TensorShuffle<type>(const int64_t shuffle_size, const int64_t inner_size,    \
+                                                    const int *perm, const type *input, type *output,        \
+                                                    const uint32_t device_id, cudaStream_t cuda_stream);
+
+template <typename T>
+__global__ void ScalarShuffleKernel(const int64_t size, const int *perm, const T *input, T *output) {
+  for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    output[pos] = input[perm[pos]];
+  }
+}
+
+template <typename T>
+__global__ void TensorShuffleKernel(const int64_t shuffle_size, const int64_t inner_size, const int *perm,
+                                    const T *input, T *output) {
+  for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < shuffle_size * inner_size;
+       pos += blockDim.x * gridDim.x) {
+    int64_t row = pos / inner_size;
+    int64_t col = pos % inner_size;
+    int64_t output_offset = perm[row] * inner_size + col;
+    output[output_offset] = input[pos];
+  }
+}
+
+template <typename T>
+void ScalarShuffle(const int64_t size, const int *perm, const T *input, T *output, const uint32_t device_id,
+                   cudaStream_t cuda_stream) {
+  ScalarShuffleKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, perm, input,
+                                                                                                 output);
+}
+
+template <typename T>
+void TensorShuffle(const int64_t shuffle_size, const int64_t inner_size, const int *perm, const T *input, T *output,
+                   const uint32_t device_id, cudaStream_t cuda_stream) {
+  int64_t total_size = shuffle_size * inner_size;
+  TensorShuffleKernel<<<CUDA_BLOCKS(device_id, total_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+    shuffle_size, inner_size, perm, input, output);
+}
+
+SHUFFLE_DECLARE(half);
+SHUFFLE_DECLARE(float);
+SHUFFLE_DECLARE(double);
+SHUFFLE_DECLARE(int8_t);
+SHUFFLE_DECLARE(int16_t);
+SHUFFLE_DECLARE(int32_t);
+SHUFFLE_DECLARE(int64_t);
+SHUFFLE_DECLARE(uint8_t);
+SHUFFLE_DECLARE(uint16_t);
+SHUFFLE_DECLARE(uint32_t);
+SHUFFLE_DECLARE(uint64_t);
+SHUFFLE_DECLARE(bool);
+SHUFFLE_DECLARE(Complex<float>);
+SHUFFLE_DECLARE(Complex<double>);
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_shuffle_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_shuffle_impl.cuh
new file mode 100644
index 00000000000..d2b05a80a36
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_shuffle_impl.cuh
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RANDOM_SHUFFLE_IMPL_CUH_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RANDOM_SHUFFLE_IMPL_CUH_
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
+template <typename T>
+CUDA_LIB_EXPORT void ScalarShuffle(const int64_t size, const int *perm, const T *input, T *output,
+                                   const uint32_t device_id, cudaStream_t cuda_stream);
+template <typename T>
+CUDA_LIB_EXPORT void TensorShuffle(const int64_t shuffle_size, const int64_t inner_size, const int *perm,
+                                   const T *input, T *output, const uint32_t device_id, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RANDOM_SHUFFLE_IMPL_CUH_
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/random/random_shuffle_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/random/random_shuffle_gpu_kernel.cc
new file mode 100644
index 00000000000..2cb65d13b07
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/random/random_shuffle_gpu_kernel.cc
@@ -0,0 +1,198 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/gpu/kernel/random/random_shuffle_gpu_kernel.h"
+#include <algorithm>
+#include <map>
+#include <memory>
+#include <random>
+#include <utility>
+#include <vector>
+#include "ir/anf.h"
+#include "utils/log_adapter.h"
+#include "kernel/common_utils.h"
+#include "include/cuda_fp16.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_shuffle_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+template <typename T>
+using Complex = mindspore::utils::Complex<T>;
+
+constexpr size_t kRandomShuffleInputsNum = 1;
+constexpr size_t kRandomShuffleOutputsNum = 1;
+constexpr size_t kScalarShapeSize = 1;
+}  // namespace
+
+bool RandomShuffleGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
+                                     const std::vector<KernelTensorPtr> &inputs,
+                                     const std::vector<KernelTensorPtr> &outputs) {
+  MS_EXCEPTION_IF_NULL(base_operator);
+  kernel_name_ = base_operator->name();
+  auto kernel_ptr = std::make_shared<ops::RandomShuffle>(base_operator->GetPrim());
+  batch_rank_ = LongToSize(kernel_ptr->get_batch_rank());
+  auto seed = kernel_ptr->get_seed();
+  auto seed2 = kernel_ptr->get_seed2();
+  if (seed == 0 && seed2 == 0) {
+    std::random_device rd;
+    std::mt19937_64 gen(rd());
+    seed = gen();
+  } else {
+    seed = (seed == 0) ? seed2 : seed;
+  }
+  generator_.seed(seed);
+
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(EXCEPTION) << "RandomShuffle does not support this kernel data type: " << kernel_attr;
+  }
+  kernel_func_ = func_list_[index].second;
+  return true;
+}
+
+int RandomShuffleGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
+                                      const std::vector<KernelTensorPtr> &inputs,
+                                      const std::vector<KernelTensorPtr> &outputs,
+                                      const std::map<uint32_t, tensor::TensorPtr> &) {
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kRandomShuffleInputsNum, kernel_name_);
+
+  int ret = KernelMod::Resize(base_operator, inputs, outputs);
+  if (ret != KRET_OK) {
+    return ret;
+  }
+
+  MS_EXCEPTION_IF_NULL(inputs[0]);
+  input_shape_ = inputs[0]->GetShapeVector();
+  if (!input_shape_.empty() && batch_rank_ >= input_shape_.size()) {
+    MS_LOG(ERROR) << "For '" << kernel_name_
+                  << "', the batch_rank should be less than input shape, but got batch_rank: " << batch_rank_
+                  << ", input shape: " << input_shape_;
+    return KRET_RESIZE_FAILED;
+  }
+
+  outer_size_ = 1;
+  for (size_t i = 0; i < batch_rank_; i++) {
+    outer_size_ *= input_shape_[i];
+  }
+  inner_size_ = 1;
+  for (size_t j = batch_rank_ + 1; j < input_shape_.size(); j++) {
+    inner_size_ *= input_shape_[j];
+  }
+
+  if (input_shape_.size() > batch_rank_) {
+    shuffle_size_ = LongToSize(input_shape_[batch_rank_]);
+  } else {
+    shuffle_size_ = 1;
+  }
+
+  workspace_size_list_.push_back(sizeof(int) * shuffle_size_ * outer_size_);
+  return ret;
+}
+
+std::vector<int> RandomShuffleGpuKernelMod::GetShuffleIndex() {
+  std::vector<int> perm(shuffle_size_);
+  int n = 0;
+  std::generate(perm.begin(), perm.end(), [&n] { return n++; });
+  std::shuffle(perm.begin(), perm.end(), generator_);
+  return perm;
+}
+
+template <typename T>
+bool RandomShuffleGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                             const std::vector<AddressPtr> &workspace,
+                                             const std::vector<AddressPtr> &outputs) {
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kRandomShuffleInputsNum, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kRandomShuffleOutputsNum, kernel_name_);
+
+  auto *input_addr = GetDeviceAddress<T>(inputs, 0);
+  auto *workspace_addr = GetDeviceAddress<int>(workspace, 0);
+  auto *output_addr = GetDeviceAddress<T>(outputs, 0);
+  if (input_shape_.empty() || input_shape_[batch_rank_] <= 1) {
+    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+      cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice,
+                      reinterpret_cast<cudaStream_t>(cuda_stream_)),
+      "RandomShuffle cudaMemcpy failed.");
+    return true;
+  }
+
+  if (input_shape_.size() <= batch_rank_ + kScalarShapeSize) {
+    for (int64_t i = 0; i < outer_size_; i++) {
+      std::vector<int> perm = GetShuffleIndex();
+      size_t offset = i * inner_size_ * SizeToLong(shuffle_size_);
+      CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+        cudaMemcpyAsync(workspace_addr + i * shuffle_size_, perm.data(), shuffle_size_ * sizeof(int),
+                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
+        "RandomShuffle cudaMemcpy failed.");
+      // Read this batch's permutation slot; using the workspace base would reuse batch 0's permutation.
+      ScalarShuffle(SizeToLong(shuffle_size_), workspace_addr + i * shuffle_size_, input_addr + offset,
+                    output_addr + offset, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
+    }
+  } else {
+    for (int64_t i = 0; i < outer_size_; i++) {
+      std::vector<int> perm = GetShuffleIndex();
+      size_t offset = i * inner_size_ * SizeToLong(shuffle_size_);
+      CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+        cudaMemcpyAsync(workspace_addr + i * shuffle_size_, perm.data(), shuffle_size_ * sizeof(int),
+                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
+        "RandomShuffle cudaMemcpy failed.");
+      // Same per-batch slot fix as the scalar branch above.
+      TensorShuffle(SizeToLong(shuffle_size_), inner_size_, workspace_addr + i * shuffle_size_, input_addr + offset,
+                    output_addr + offset, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
+    }
+  }
+
+  return true;
+}
+
+std::vector<std::pair<KernelAttr, RandomShuffleGpuKernelMod::RandomShuffleFunc>>
+  RandomShuffleGpuKernelMod::func_list_ =
+    {{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+      &RandomShuffleGpuKernelMod::LaunchKernel<half>},
+     {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+      &RandomShuffleGpuKernelMod::LaunchKernel<float>},
+     {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+      &RandomShuffleGpuKernelMod::LaunchKernel<double>},
+     {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+      &RandomShuffleGpuKernelMod::LaunchKernel<int8_t>},
+     {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+      &RandomShuffleGpuKernelMod::LaunchKernel<int16_t>},
+     {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+      &RandomShuffleGpuKernelMod::LaunchKernel<int32_t>},
+     {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+      &RandomShuffleGpuKernelMod::LaunchKernel<int64_t>},
+     {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+      &RandomShuffleGpuKernelMod::LaunchKernel<uint8_t>},
+     {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
+      &RandomShuffleGpuKernelMod::LaunchKernel<uint16_t>},
+     {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
+      &RandomShuffleGpuKernelMod::LaunchKernel<uint32_t>},
+     {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
+      &RandomShuffleGpuKernelMod::LaunchKernel<uint64_t>},
+     {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+      &RandomShuffleGpuKernelMod::LaunchKernel<bool>},
+     {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
+      &RandomShuffleGpuKernelMod::LaunchKernel<Complex<float>>},
+     {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
+      &RandomShuffleGpuKernelMod::LaunchKernel<Complex<double>>}};
+
+std::vector<KernelAttr> RandomShuffleGpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, RandomShuffleFunc> &pair) { return pair.first; });
+  return support_list;
+}
+
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, RandomShuffle, RandomShuffleGpuKernelMod);
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/random/random_shuffle_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/random/random_shuffle_gpu_kernel.h
new file mode 100644
index 00000000000..7d8ef8efada
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/random/random_shuffle_gpu_kernel.h
@@ -0,0 +1,80 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_SHUFFLE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_SHUFFLE_GPU_KERNEL_H_
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <random>
+#include <utility>
+#include <vector>
+#include "mindspore/core/ops/random_shuffle.h"
+#include "kernel/common_utils.h"
+#include "plugin/device/gpu/kernel/gpu_kernel.h"
+#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_op_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+class RandomShuffleGpuKernelMod : public NativeGpuKernelMod {
+ public:
+  RandomShuffleGpuKernelMod() = default;
+  ~RandomShuffleGpuKernelMod() override = default;
+
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
+    if (is_null_input_) {
+      return true;
+    }
+    cuda_stream_ = cuda_stream;
+    return kernel_func_(this, inputs, workspace, outputs);
+  }
+
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
+
+  std::vector<KernelAttr> GetOpSupport() override;
+
+ private:
+  template <typename T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                    const std::vector<AddressPtr> &outputs);
+  std::vector<int> GetShuffleIndex();
+
+  using RandomShuffleFunc =
+    std::function<bool(RandomShuffleGpuKernelMod *, const std::vector<AddressPtr> &,
+                       const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
+  static std::vector<std::pair<KernelAttr, RandomShuffleFunc>> func_list_;
+  RandomShuffleFunc kernel_func_;
+
+  int64_t outer_size_{1};
+  int64_t inner_size_{1};
+  size_t shuffle_size_{1};
+  size_t batch_rank_{0};
+  bool is_null_input_{false};
+  void *cuda_stream_{nullptr};
+  std::vector<int64_t> input_shape_;
+  std::default_random_engine generator_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_SHUFFLE_GPU_KERNEL_H_
diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_base.py b/mindspore/python/mindspore/ops/_vmap/vmap_base.py
index c874bba9cbc..32e48ac7343 100644
--- a/mindspore/python/mindspore/ops/_vmap/vmap_base.py
+++ b/mindspore/python/mindspore/ops/_vmap/vmap_base.py
@@ -26,7 +26,7 @@ from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops.operations import nn_ops as nps
 from mindspore.ops.composite import _VmapGeneralPreprocess
 from mindspore.ops.primitive import Primitive
-from mindspore.ops.operations.random_ops import UniformCandidateSampler
+from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle
 from mindspore.ops._grad.grad_base import BpropRegistry as VmapRuleRegistry
 
@@ -512,4 +512,5 @@ _ops_vmap_clone_prim_dict = {
     "SparseApplyAdagrad": P.SparseApplyAdagrad,
     "SparseApplyAdagradV2": P.SparseApplyAdagradV2,
     "SparseApplyFtrl": P.SparseApplyFtrl,
+    "RandomShuffle": RandomShuffle,
 }
diff --git a/mindspore/python/mindspore/ops/function/random_func.py b/mindspore/python/mindspore/ops/function/random_func.py
index 38b2d3c7fec..54943beec43 100755
--- a/mindspore/python/mindspore/ops/function/random_func.py
+++ b/mindspore/python/mindspore/ops/function/random_func.py
@@ -408,7 +408,7 @@ def random_shuffle(x, seed=0, seed2=0):
         TypeError: If data type of `seed` or `seed2` is not int.
 
     Supported Platforms:
-        ``CPU``
+        ``CPU`` ``GPU``
 
     Examples:
         >>> x = Tensor(np.array([1, 2, 3, 4]), mstype.float32)
diff --git a/mindspore/python/mindspore/ops/operations/random_ops.py b/mindspore/python/mindspore/ops/operations/random_ops.py
index 8e28a94a375..7042506926c 100755
--- a/mindspore/python/mindspore/ops/operations/random_ops.py
+++ b/mindspore/python/mindspore/ops/operations/random_ops.py
@@ -942,7 +942,7 @@ class RandomShuffle(Primitive):
         TypeError: If data type of `seed` or `seed2` is not int.
 
     Supported Platforms:
-        ``CPU``
+        ``CPU`` ``GPU``
 
     Examples:
         >>> x = Tensor(np.array([1, 2, 3, 4]), mstype.float32)
diff --git a/tests/st/ops/gpu/test_random_shuffle.py b/tests/st/ops/gpu/test_random_shuffle.py
new file mode 100644
index 00000000000..2e6b70f4402
--- /dev/null
+++ b/tests/st/ops/gpu/test_random_shuffle.py
@@ -0,0 +1,139 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import pytest
+import numpy as np
+import mindspore.nn as nn
+from mindspore.ops import functional as F
+from mindspore import Tensor, context
+
+context.set_context(device_target="GPU")
+
+
+class RandomShuffleNet(nn.Cell):
+    def __init__(self, seed=0, seed2=0):
+        super(RandomShuffleNet, self).__init__()
+        self.seed = seed
+        self.seed2 = seed2
+
+    def construct(self, x):
+        return F.random_shuffle(x, self.seed, self.seed2)
+
+
+@pytest.mark.level0
+@pytest.mark.env_onecard
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+@pytest.mark.parametrize("dtype", [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16,
+                                   np.uint32, np.uint64, np.bool_, np.complex64, np.complex128,
+                                   np.float64, np.float32, np.float16])
+def test_random_shuffle_op_dtype(mode, dtype):
+    """
+    Feature: gpu RandomShuffle
+    Description: test the Tensor with all supported types.
+    Expectation: success.
+    """
+    context.set_context(mode=mode)
+
+    net = RandomShuffleNet(seed=1, seed2=1)
+    x = Tensor(np.array([1, 2, 3, 4, 5]).astype(dtype))
+    expect_shape = (5,)
+    output = net(x)
+    assert output.shape == expect_shape
+
+
+@pytest.mark.level0
+@pytest.mark.env_onecard
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
+@pytest.mark.parametrize("shape", [(5,), (2, 3), (12, 3, 5), (3, 4, 2, 3),
+                                   (3, 4, 2, 3, 4), (3, 4, 2, 3, 4, 4),
+                                   (3, 4, 2, 3, 4, 5, 3)])
+def test_random_shuffle_op_tensor(mode, shape):
+    """
+    Feature: gpu RandomShuffle
+    Description: test the 1-7D Tensor.
+    Expectation: success.
+ """ + context.set_context(mode=mode) + net = RandomShuffleNet(seed=3, seed2=1) + x = Tensor(np.random.randn(*shape).astype(np.float32)) + output = net(x) + expect_shape = shape + assert output.shape == expect_shape + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_x86_gpu_training +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_random_shuffle_op_scalar(mode): + """ + Feature: cpu RandomShuffle + Description: test the scalar Tensor. + Expectation: success. + """ + context.set_context(mode=mode) + net = RandomShuffleNet(seed=3, seed2=1) + x = Tensor(np.array(2.5).astype(np.float32)) + output = net(x) + assert output == x + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_x86_gpu_training +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_random_shuffle_op_dynamic_shape(mode): + """ + Feature: cpu RandomShuffle + Description: test the Tensor with dynamic shape. + Expectation: success. + """ + context.set_context(mode=mode) + dyn_net = RandomShuffleNet(seed=6, seed2=2) + net = RandomShuffleNet(seed=6, seed2=2) + x = Tensor(np.random.randn(3, 4, 5).astype(np.float32)) + x_dyn = Tensor(shape=[None for _ in x.shape], dtype=x.dtype) + dyn_net.set_inputs(x_dyn) + output_dyn = dyn_net(x) + out = net(x) + assert (output_dyn.asnumpy() == out.asnumpy()).all() + + +@pytest.mark.level0 +@pytest.mark.env_onecard +@pytest.mark.platform_x86_gpu_training +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_random_shuffle_op_exception(mode): + """ + Feature: cpu RandomShuffle + Description: test the Tensor with exception. + Expectation: success. + """ + context.set_context(mode=mode) + x = Tensor(np.random.randn(3, 4, 5).astype(np.float32)) + + with pytest.raises(TypeError): + F.random_shuffle(2, seed=3, seed2=1) + + with pytest.raises(ValueError): + F.random_shuffle(x, seed=-3, seed2=1) + + with pytest.raises(TypeError): + F.random_shuffle(x, seed=2, seed2=1.6) + + with pytest.raises(TypeError): + F.random_shuffle(x, seed=True, seed2=0)