!40665 add randomshuffle_gpu

Merge pull request !40665 from 范吉斌/randomshuffle_gpu
i-robot 2022-09-26 03:23:16 +00:00 committed by Gitee
commit 813b57f425
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 526 additions and 3 deletions
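
As context for reviewers, here is a minimal usage sketch of the operator this change enables on GPU. It only uses calls that appear in the new test file below (F.random_shuffle and context.set_context); the concrete seed values are illustrative.

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import functional as F

# Select the newly supported GPU backend (the op was previously CPU-only).
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

x = Tensor(np.array([1, 2, 3, 4], dtype=np.float32))
# seed/seed2 make the permutation reproducible; if both are 0, a random seed is drawn.
out = F.random_shuffle(x, seed=3, seed2=1)
print(out.shape)  # (4,) -- same shape, elements permuted along axis 0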


@@ -0,0 +1,79 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "random_shuffle_impl.cuh"
#include <cuda_runtime.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "include/cuda_fp16.h"
template <typename T>
using Complex = mindspore::utils::Complex<T>;
#define SHUFFLE_DECLARE(type) \
template CUDA_LIB_EXPORT void ScalarShuffle<type>(const int64_t size, const int *perm, const type *input, \
type *output, const uint32_t device_id, cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT void TensorShuffle<type>(const int64_t shuffle_size, const int64_t inner_size, \
const int *perm, const type *input, type *output, \
const uint32_t device_id, cudaStream_t cuda_stream);
template <typename T>
__global__ void ScalarShuffleKernel(const int64_t size, const int *perm, const T *input, T *output) {
for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
output[pos] = input[perm[pos]];
}
}
template <typename T>
__global__ void TensorShuffleKernel(const int64_t shuffle_size, const int64_t inner_size, const int *perm,
const T *input, T *output) {
for (int64_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < shuffle_size * inner_size;
pos += blockDim.x * gridDim.x) {
int64_t row = pos / inner_size;
int64_t col = pos % inner_size;
int64_t output_offset = perm[row] * inner_size + col;
output[output_offset] = input[pos];
}
}
template <typename T>
void ScalarShuffle(const int64_t size, const int *perm, const T *input, T *output, const uint32_t device_id,
cudaStream_t cuda_stream) {
ScalarShuffleKernel<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, perm, input,
output);
}
template <typename T>
void TensorShuffle(const int64_t shuffle_size, const int64_t inner_size, const int *perm, const T *input, T *output,
const uint32_t device_id, cudaStream_t cuda_stream) {
int64_t total_size = shuffle_size * inner_size;
TensorShuffleKernel<<<CUDA_BLOCKS(device_id, total_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
shuffle_size, inner_size, perm, input, output);
}
SHUFFLE_DECLARE(half);
SHUFFLE_DECLARE(float);
SHUFFLE_DECLARE(double);
SHUFFLE_DECLARE(int8_t);
SHUFFLE_DECLARE(int16_t);
SHUFFLE_DECLARE(int32_t);
SHUFFLE_DECLARE(int64_t);
SHUFFLE_DECLARE(uint8_t);
SHUFFLE_DECLARE(uint16_t);
SHUFFLE_DECLARE(uint32_t);
SHUFFLE_DECLARE(uint64_t);
SHUFFLE_DECLARE(bool);
SHUFFLE_DECLARE(Complex<float>);
SHUFFLE_DECLARE(Complex<double>);
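
As a reading aid rather than part of the change, here is a NumPy sketch of what the two kernels above compute. ScalarShuffleKernel gathers elements through perm (output[pos] = input[perm[pos]]); TensorShuffleKernel scatters whole inner-size rows to the slot given by perm (output[perm[row] * inner + col] = input[row * inner + col]). The two therefore apply mutually inverse permutations, which does not change the result distribution when perm is uniformly random. The helper names below are illustrative only.

import numpy as np

def scalar_shuffle_ref(x, perm):
    # Gather: output[pos] = input[perm[pos]]
    return x[perm]

def tensor_shuffle_ref(x, perm, shuffle_size, inner_size):
    # Scatter: output[perm[row] * inner_size + col] = input[row * inner_size + col]
    out = np.empty_like(x)
    out.reshape(shuffle_size, inner_size)[perm] = x.reshape(shuffle_size, inner_size)
    return out

x = np.arange(6, dtype=np.float32)        # two rows with inner_size 3
perm = np.array([1, 0])
print(scalar_shuffle_ref(x[:2], perm))    # [1. 0.]
print(tensor_shuffle_ref(x, perm, 2, 3))  # rows swapped: [3. 4. 5. 0. 1. 2.]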


@@ -0,0 +1,26 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RANDOM_SHUFFLE_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RANDOM_SHUFFLE_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T>
CUDA_LIB_EXPORT void ScalarShuffle(const int64_t size, const int *perm, const T *input, T *output,
const uint32_t device_id, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void TensorShuffle(const int64_t shuffle_size, const int64_t inner_size, const int *perm,
const T *input, T *output, const uint32_t device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RANDOM_SHUFFLE_IMPL_CUH_


@@ -0,0 +1,198 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/random/random_shuffle_gpu_kernel.h"
#include <functional>
#include <utility>
#include <memory>
#include <string>
#include <algorithm>
#include <complex>
#include "ir/anf.h"
#include "utils/log_adapter.h"
#include "kernel/common_utils.h"
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_shuffle_impl.cuh"
namespace mindspore {
namespace kernel {
namespace {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
constexpr size_t kRandomShuffleInputsNum = 1;
constexpr size_t kRandomShuffleOutputsNum = 1;
constexpr size_t kScalarShapeSize = 1;
} // namespace
bool RandomShuffleGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
auto kernel_ptr = std::make_shared<ops::RandomShuffle>(base_operator->GetPrim());
batch_rank_ = LongToSize(kernel_ptr->get_batch_rank());
auto seed = kernel_ptr->get_seed();
auto seed2 = kernel_ptr->get_seed2();
if (seed == 0 && seed2 == 0) {
std::random_device rd;
std::mt19937_64 gen(rd());
seed = gen();
} else {
seed = (seed == 0) ? seed2 : seed;
}
generator_.seed(seed);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "RandomShuffle does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
return true;
}
int RandomShuffleGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kRandomShuffleInputsNum, kernel_name_);
int ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
MS_EXCEPTION_IF_NULL(inputs[0]);
input_shape_ = inputs[0]->GetShapeVector();
if (!input_shape_.empty() && batch_rank_ >= input_shape_.size()) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the batch_rank should be less than input shape, but got batch_rank: " << batch_rank_
<< ", input shape: " << input_shape_;
return KRET_RESIZE_FAILED;
}
outer_size_ = 1;
for (size_t i = 0; i < batch_rank_; i++) {
outer_size_ *= input_shape_[i];
}
inner_size_ = 1;
for (size_t j = batch_rank_ + 1; j < input_shape_.size(); j++) {
inner_size_ *= input_shape_[j];
}
if (input_shape_.size() > batch_rank_) {
shuffle_size_ = LongToSize(input_shape_[batch_rank_]);
} else {
shuffle_size_ = 1;
}
workspace_size_list_.push_back(sizeof(int) * shuffle_size_ * outer_size_);
return ret;
}
std::vector<int> RandomShuffleGpuKernelMod::GetShuffleIndex() {
std::vector<int> perm(shuffle_size_);
int n = 0;
std::generate(perm.begin(), perm.end(), [&n] { return n++; });
std::shuffle(perm.begin(), perm.end(), generator_);
return perm;
}
template <typename T>
bool RandomShuffleGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kRandomShuffleInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kRandomShuffleOutputsNum, kernel_name_);
auto *input_addr = GetDeviceAddress<T>(inputs, 0);
auto *workspace_addr = GetDeviceAddress<int>(workspace, 0);
auto *output_addr = GetDeviceAddress<T>(outputs, 0);
if (input_shape_.empty() || input_shape_[batch_rank_] <= 1) {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"RandomShuffle cudaMemcpy failed.");
return true;
}
if (input_shape_.size() <= batch_rank_ + kScalarShapeSize) {
for (int64_t i = 0; i < outer_size_; i++) {
std::vector<int> perm = GetShuffleIndex();
size_t offset = i * inner_size_ * SizeToLong(shuffle_size_);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(workspace_addr + i * shuffle_size_, perm.data(), shuffle_size_ * sizeof(int),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"RandomShuffle cudaMemcpy failed.");
ScalarShuffle(SizeToLong(shuffle_size_), workspace_addr, input_addr + offset, output_addr + offset, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
}
} else {
for (int64_t i = 0; i < outer_size_; i++) {
std::vector<int> perm = GetShuffleIndex();
size_t offset = i * inner_size_ * SizeToLong(shuffle_size_);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(workspace_addr + i * shuffle_size_, perm.data(), shuffle_size_ * sizeof(int),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"RandomShuffle cudaMemcpy failed.");
TensorShuffle(SizeToLong(shuffle_size_), inner_size_, workspace_addr, input_addr + offset, output_addr + offset,
device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
}
}
return true;
}
std::vector<std::pair<KernelAttr, RandomShuffleGpuKernelMod::RandomShuffleFunc>> RandomShuffleGpuKernelMod::func_list_ =
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&RandomShuffleGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&RandomShuffleGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&RandomShuffleGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
&RandomShuffleGpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&RandomShuffleGpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&RandomShuffleGpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&RandomShuffleGpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
&RandomShuffleGpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
&RandomShuffleGpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
&RandomShuffleGpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
&RandomShuffleGpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&RandomShuffleGpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&RandomShuffleGpuKernelMod::LaunchKernel<Complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&RandomShuffleGpuKernelMod::LaunchKernel<Complex<double>>}};
std::vector<KernelAttr> RandomShuffleGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, RandomShuffleFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, RandomShuffle, RandomShuffleGpuKernelMod);
} // namespace kernel
} // namespace mindspore
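
A brief sketch of how Resize above decomposes the input shape, for readers of the launch loop: the leading batch_rank axes are folded into outer_size_, axis batch_rank is the shuffled axis (shuffle_size_), and the trailing axes are folded into inner_size_; one host-side permutation is drawn per outer batch. The helper below just mirrors that arithmetic and is not part of the kernel.

from functools import reduce
from operator import mul

def split_shape(shape, batch_rank):
    """Mirror of the outer/shuffle/inner split in RandomShuffleGpuKernelMod::Resize."""
    outer = reduce(mul, shape[:batch_rank], 1)
    shuffle = shape[batch_rank] if len(shape) > batch_rank else 1
    inner = reduce(mul, shape[batch_rank + 1:], 1)
    return outer, shuffle, inner

# e.g. shape (3, 4, 2, 5) with batch_rank = 1:
# 3 independent batches, each shuffling 4 blocks of 2 * 5 = 10 elements.
print(split_shape((3, 4, 2, 5), 1))  # (3, 4, 10)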


@@ -0,0 +1,80 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_SHUFFLE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_SHUFFLE_GPU_KERNEL_H_
#include <curand_kernel.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <map>
#include <string>
#include <utility>
#include "mindspore/core/ops/random_shuffle.h"
#include "kernel/common_utils.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/random_op_impl.cuh"
namespace mindspore {
namespace kernel {
class RandomShuffleGpuKernelMod : public NativeGpuKernelMod {
public:
RandomShuffleGpuKernelMod() = default;
~RandomShuffleGpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
std::vector<int> GetShuffleIndex();
using RandomShuffleFunc =
std::function<bool(RandomShuffleGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, RandomShuffleFunc>> func_list_;
RandomShuffleFunc kernel_func_;
int64_t outer_size_{1};
int64_t inner_size_{1};
size_t shuffle_size_{1};
size_t batch_rank_{0};
bool is_null_input_{false};
void *cuda_stream_{nullptr};
std::vector<int64_t> input_shape_;
std::default_random_engine generator_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_SHUFFLE_GPU_KERNEL_H_


@@ -26,7 +26,7 @@ from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import nn_ops as nps
from mindspore.ops.composite import _VmapGeneralPreprocess
from mindspore.ops.primitive import Primitive
-from mindspore.ops.operations.random_ops import UniformCandidateSampler
+from mindspore.ops.operations.random_ops import UniformCandidateSampler, RandomShuffle
from mindspore.ops._grad.grad_base import BpropRegistry as VmapRuleRegistry
@@ -512,4 +512,5 @@ _ops_vmap_clone_prim_dict = {
"SparseApplyAdagrad": P.SparseApplyAdagrad,
"SparseApplyAdagradV2": P.SparseApplyAdagradV2,
"SparseApplyFtrl": P.SparseApplyFtrl,
"RandomShuffle": RandomShuffle,
}
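
The registration above adds RandomShuffle to the set of primitives cloned when batched under vmap. A hedged sketch of what that enables is below; ops.vmap is the standard MindSpore batching transform, and the expectation that each batch element gets an independent permutation follows from the batch_rank handling in the GPU kernel above, so treat the comment as an expectation rather than a guarantee.

import numpy as np
from mindspore import Tensor, context, ops
from mindspore.ops.operations.random_ops import RandomShuffle

context.set_context(device_target="GPU")

# Constructor arguments mirror the seed/seed2 attributes read in Init above.
shuffle = RandomShuffle(seed=5, seed2=0)

def shuffle_row(row):
    return shuffle(row)

x = Tensor(np.arange(12, dtype=np.float32).reshape(3, 4))
# vmap over axis 0: each of the 3 rows is shuffled along its own length-4 axis.
batched = ops.vmap(shuffle_row, in_axes=0, out_axes=0)
print(batched(x).shape)  # (3, 4)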


@@ -408,7 +408,7 @@ def random_shuffle(x, seed=0, seed2=0):
TypeError: If data type of `seed` or `seed2` is not int.
Supported Platforms:
-``CPU``
+``CPU`` ``GPU``
Examples:
>>> x = Tensor(np.array([1, 2, 3, 4]), mstype.float32)


@@ -919,7 +919,7 @@ class RandomShuffle(Primitive):
TypeError: If data type of `seed` or `seed2` is not int.
Supported Platforms:
-``CPU``
+``CPU`` ``GPU``
Examples:
>>> x = Tensor(np.array([1, 2, 3, 4]), mstype.float32)


@@ -0,0 +1,139 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore.ops import functional as F
from mindspore import Tensor, context
context.set_context(device_target="GPU")
class RandomShuffleNet(nn.Cell):
def __init__(self, seed=0, seed2=0):
super(RandomShuffleNet, self).__init__()
self.seed = seed
self.seed2 = seed2
def construct(self, x):
return F.random_shuffle(x, self.seed, self.seed2)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize("dtype", [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16,
np.uint32, np.uint64, np.bool_, np.complex64, np.complex128,
np.float64, np.float32, np.float16])
def test_random_shuffle_op_dtype(mode, dtype):
"""
Feature: gpu RandomShuffle
Description: test the Tensor with all supported types.
Expectation: success.
"""
context.set_context(mode=mode)
net = RandomShuffleNet(seed=1, seed2=1)
x = Tensor(np.array([1, 2, 3, 4, 5]).astype(dtype))
expect_shape = (5,)
output = net(x)
assert output.shape == expect_shape
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
@pytest.mark.parametrize("shape", [(5,), (2, 3), (12, 3, 5), (3, 4, 2, 3),
(3, 4, 2, 3, 4), (3, 4, 2, 3, 4, 4),
(3, 4, 2, 3, 4, 5, 3)])
def test_random_shuffle_op_tensor(mode, shape):
"""
Feature: gpu RandomShuffle
Description: test 1-7D Tensors.
Expectation: success.
"""
context.set_context(mode=mode)
net = RandomShuffleNet(seed=3, seed2=1)
x = Tensor(np.random.randn(*shape).astype(np.float32))
output = net(x)
expect_shape = shape
assert output.shape == expect_shape
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_random_shuffle_op_scalar(mode):
"""
Feature: gpu RandomShuffle
Description: test the scalar Tensor.
Expectation: success.
"""
context.set_context(mode=mode)
net = RandomShuffleNet(seed=3, seed2=1)
x = Tensor(np.array(2.5).astype(np.float32))
output = net(x)
assert output == x
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_random_shuffle_op_dynamic_shape(mode):
"""
Feature: gpu RandomShuffle
Description: test the Tensor with dynamic shape.
Expectation: success.
"""
context.set_context(mode=mode)
dyn_net = RandomShuffleNet(seed=6, seed2=2)
net = RandomShuffleNet(seed=6, seed2=2)
x = Tensor(np.random.randn(3, 4, 5).astype(np.float32))
x_dyn = Tensor(shape=[None for _ in x.shape], dtype=x.dtype)
dyn_net.set_inputs(x_dyn)
output_dyn = dyn_net(x)
out = net(x)
assert (output_dyn.asnumpy() == out.asnumpy()).all()
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_random_shuffle_op_exception(mode):
"""
Feature: gpu RandomShuffle
Description: test RandomShuffle with invalid inputs.
Expectation: raise the expected exceptions.
"""
context.set_context(mode=mode)
x = Tensor(np.random.randn(3, 4, 5).astype(np.float32))
with pytest.raises(TypeError):
F.random_shuffle(2, seed=3, seed2=1)
with pytest.raises(ValueError):
F.random_shuffle(x, seed=-3, seed2=1)
with pytest.raises(TypeError):
F.random_shuffle(x, seed=2, seed2=1.6)
with pytest.raises(TypeError):
F.random_shuffle(x, seed=True, seed2=0)