forked from mindspore-Ecosystem/mindspore
!15140 Add GPU ops:depthtospace and spacetodepth
From: @kanghui0204 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
4631facee9
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/depthtospace_gpu_kernel.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
DepthToSpaceFwdKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
DepthToSpaceFwdKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
DepthToSpaceFwdKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
DepthToSpaceFwdKernel, int64_t)
|
||||
MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
DepthToSpaceFwdKernel, int16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
DepthToSpaceFwdKernel, int8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
DepthToSpaceFwdKernel, uint8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
DepthToSpaceFwdKernel, uint16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
DepthToSpaceFwdKernel, uint32_t)
|
||||
MS_REG_GPU_KERNEL_ONE(DepthToSpace, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
DepthToSpaceFwdKernel, uint64_t)
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,146 @@
|
|||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEPTHTOSPACE_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEPTHTOSPACE_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/depthtospace_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class DepthToSpaceFwdKernel : public GpuKernel {
|
||||
public:
|
||||
DepthToSpaceFwdKernel() { ResetResource(); }
|
||||
~DepthToSpaceFwdKernel() = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
// get device buffer ptr
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
// get device buffer shape ptr
|
||||
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
|
||||
size_t *output_shape = GetDeviceAddress<size_t>(workspace, 1);
|
||||
|
||||
// buffer shape memcpy from host to device
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size1_, cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync input_shape failed");
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(output_shape, &output_shape_[0], workspace_size2_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync input_shape failed");
|
||||
// get input size
|
||||
size_t size = input_size_ / sizeof(T);
|
||||
|
||||
// call cuda kernel
|
||||
CalDepthToSpace(size, input, input_shape, output_shape, block_size_, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
block_size_ = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "block_size"));
|
||||
// check input num and output num
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but DepthToSpace needs 1 input.";
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", DepthToSpace needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
// check input_shape
|
||||
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
|
||||
shape_size_ = input_shape.size();
|
||||
if (shape_size_ != DEPTHTOSPACE_BUFFER_DIMENSION) {
|
||||
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but DepthToSpace supports 4-D tensor.";
|
||||
}
|
||||
// get input and out put information
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < shape_size_; i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
input_shape_.push_back(input_shape[i]);
|
||||
}
|
||||
input_size_ *= sizeof(T);
|
||||
output_size_ = input_size_;
|
||||
output_shape_.push_back(input_shape[0]);
|
||||
output_shape_.push_back(input_shape[1] / block_size_ / block_size_);
|
||||
output_shape_.push_back(input_shape[2] * block_size_);
|
||||
output_shape_.push_back(input_shape[3] * block_size_);
|
||||
// Private members Initialize
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
shape_size_ = 0;
|
||||
input_size_ = 0;
|
||||
output_size_ = 0;
|
||||
block_size_ = 0;
|
||||
workspace_size1_ = 0;
|
||||
workspace_size2_ = 0;
|
||||
|
||||
input_shape_.clear();
|
||||
output_shape_.clear();
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
workspace_size1_ = shape_size_ * sizeof(size_t);
|
||||
workspace_size2_ = shape_size_ * sizeof(size_t);
|
||||
workspace_size_list_.push_back(workspace_size1_);
|
||||
workspace_size_list_.push_back(workspace_size2_);
|
||||
return;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
size_t shape_size_;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t block_size_;
|
||||
size_t workspace_size1_;
|
||||
size_t workspace_size2_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEPTHTOSPACE_KERNEL_H_
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/spacetodepth_gpu_kernel.h"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpaceToDepthFwdKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SpaceToDepthFwdKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpaceToDepthFwdKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpaceToDepthFwdKernel, int64_t)
|
||||
MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
SpaceToDepthFwdKernel, int16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
SpaceToDepthFwdKernel, int8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
SpaceToDepthFwdKernel, uint8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
SpaceToDepthFwdKernel, uint16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
SpaceToDepthFwdKernel, uint32_t)
|
||||
MS_REG_GPU_KERNEL_ONE(SpaceToDepth, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
SpaceToDepthFwdKernel, uint64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,146 @@
|
|||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACETODEPTH_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACETODEPTH_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/spacetodepth_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class SpaceToDepthFwdKernel : public GpuKernel {
|
||||
public:
|
||||
SpaceToDepthFwdKernel() { ResetResource(); }
|
||||
~SpaceToDepthFwdKernel() = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
// get device buffer ptr
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
// get device buffer shape ptr
|
||||
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
|
||||
size_t *output_shape = GetDeviceAddress<size_t>(workspace, 1);
|
||||
|
||||
// buffer shape memcpy from host to device
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size1_, cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync input_shape failed");
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(output_shape, &output_shape_[0], workspace_size2_,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync input_shape failed");
|
||||
// get input size
|
||||
size_t size = input_size_ / sizeof(T);
|
||||
|
||||
// call cuda kernel
|
||||
CalSpaceToDepth(size, input, input_shape, output_shape, block_size_, output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
block_size_ = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "block_size"));
|
||||
// check input num and output num
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but SpaceToDepth needs 1 input.";
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", SpaceToDepth needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
// check input_shape
|
||||
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
|
||||
shape_size_ = input_shape.size();
|
||||
if (shape_size_ != SPACETODEPTH_BUFFER_DIMENSION) {
|
||||
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but SpaceToDepth supports 4-D tensor.";
|
||||
}
|
||||
// get input and out put information
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < shape_size_; i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
input_shape_.push_back(input_shape[i]);
|
||||
}
|
||||
input_size_ *= sizeof(T);
|
||||
output_size_ = input_size_;
|
||||
output_shape_.push_back(input_shape[0]);
|
||||
output_shape_.push_back(input_shape[1] * block_size_ * block_size_);
|
||||
output_shape_.push_back(input_shape[2] / block_size_);
|
||||
output_shape_.push_back(input_shape[3] / block_size_);
|
||||
// Private members Initialize
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
shape_size_ = 0;
|
||||
input_size_ = 0;
|
||||
output_size_ = 0;
|
||||
block_size_ = 0;
|
||||
workspace_size1_ = 0;
|
||||
workspace_size2_ = 0;
|
||||
|
||||
input_shape_.clear();
|
||||
output_shape_.clear();
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
workspace_size1_ = shape_size_ * sizeof(size_t);
|
||||
workspace_size2_ = shape_size_ * sizeof(size_t);
|
||||
workspace_size_list_.push_back(workspace_size1_);
|
||||
workspace_size_list_.push_back(workspace_size2_);
|
||||
return;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
size_t shape_size_;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t block_size_;
|
||||
size_t workspace_size1_;
|
||||
size_t workspace_size2_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACETODEPTH_KERNEL__H_
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include "depthtospace_impl.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void DepthToSpace(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
|
||||
const size_t r, T *output) {
|
||||
size_t temp_stride = 0;
|
||||
size_t temp_pos = 0;
|
||||
size_t input_pos = 0;
|
||||
size_t input_pos_array[DEPTHTOSPACE_BUFFER_DIMENSION];
|
||||
size_t output_pos_array[DEPTHTOSPACE_BUFFER_DIMENSION];
|
||||
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
temp_stride = output_shape[1] * output_shape[2] * output_shape[3];
|
||||
output_pos_array[0] = pos / temp_stride;
|
||||
temp_pos = pos % temp_stride;
|
||||
|
||||
for (size_t i = 1; i < DEPTHTOSPACE_BUFFER_DIMENSION; i++) {
|
||||
temp_stride /= output_shape[i];
|
||||
output_pos_array[i] = temp_pos / temp_stride;
|
||||
temp_pos %= temp_stride;
|
||||
}
|
||||
|
||||
input_pos_array[0] = output_pos_array[0];
|
||||
input_pos_array[1] = output_pos_array[1] * r * r + r * (output_pos_array[2] % r) + output_pos_array[3] % r;
|
||||
input_pos_array[2] = output_pos_array[2] / r;
|
||||
input_pos_array[3] = output_pos_array[3] / r;
|
||||
|
||||
for (size_t i = 0; i < 3; ++i) {
|
||||
input_pos += input_pos_array[i];
|
||||
input_pos *= input_shape[i + 1];
|
||||
}
|
||||
input_pos += input_pos_array[3];
|
||||
|
||||
output[pos] = input[input_pos];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalDepthToSpace(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
|
||||
const size_t r, T *output, cudaStream_t cuda_stream) {
|
||||
DepthToSpace<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, input_shape, output_shape, r, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalDepthToSpace<float>(const size_t size, const float *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, float *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalDepthToSpace<half>(const size_t size, const half *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, half *output, cudaStream_t cuda_stream);
|
||||
template void CalDepthToSpace<int>(const size_t size, const int *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, int *output, cudaStream_t cuda_stream);
|
||||
template void CalDepthToSpace<int64_t>(const size_t size, const int64_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, int64_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalDepthToSpace<int16_t>(const size_t size, const int16_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, int16_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalDepthToSpace<int8_t>(const size_t size, const int8_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, int8_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalDepthToSpace<uint8_t>(const size_t size, const uint8_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, uint8_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalDepthToSpace<uint16_t>(const size_t size, const uint16_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, uint16_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalDepthToSpace<uint32_t>(const size_t size, const uint32_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, uint32_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalDepthToSpace<uint64_t>(const size_t size, const uint64_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, uint64_t *output,
|
||||
cudaStream_t cuda_stream);
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DEPTHTOSPACE_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DEPTHTOSPACE_H_
|
||||
|
||||
#define DEPTHTOSPACE_BUFFER_DIMENSION 4
|
||||
template <typename T>
|
||||
void CalDepthToSpace(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
|
||||
const size_t r, T *output, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DEPTHTOSPACE_H_
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include "spacetodepth_impl.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void SpaceToDepth(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
|
||||
const size_t r, T *output) {
|
||||
size_t temp_stride = 0;
|
||||
size_t temp_pos = 0;
|
||||
size_t output_pos = 0;
|
||||
size_t input_pos_array[SPACETODEPTH_BUFFER_DIMENSION];
|
||||
size_t output_pos_array[SPACETODEPTH_BUFFER_DIMENSION];
|
||||
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
temp_stride = input_shape[1] * input_shape[2] * input_shape[3];
|
||||
input_pos_array[0] = pos / temp_stride;
|
||||
temp_pos = pos % temp_stride;
|
||||
|
||||
for (size_t i = 1; i < SPACETODEPTH_BUFFER_DIMENSION; i++) {
|
||||
temp_stride /= input_shape[i];
|
||||
input_pos_array[i] = temp_pos / temp_stride;
|
||||
temp_pos %= temp_stride;
|
||||
}
|
||||
|
||||
output_pos_array[0] = input_pos_array[0];
|
||||
output_pos_array[1] = input_pos_array[1] * r * r + r * (input_pos_array[2] % r) + input_pos_array[3] % r;
|
||||
output_pos_array[2] = input_pos_array[2] / r;
|
||||
output_pos_array[3] = input_pos_array[3] / r;
|
||||
|
||||
for (size_t i = 0; i < 3; ++i) {
|
||||
output_pos += output_pos_array[i];
|
||||
output_pos *= output_shape[i + 1];
|
||||
}
|
||||
output_pos += output_pos_array[3];
|
||||
|
||||
output[output_pos] = input[pos];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalSpaceToDepth(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
|
||||
const size_t r, T *output, cudaStream_t cuda_stream) {
|
||||
SpaceToDepth<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, input_shape, output_shape, r, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalSpaceToDepth<float>(const size_t size, const float *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, float *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSpaceToDepth<half>(const size_t size, const half *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, half *output, cudaStream_t cuda_stream);
|
||||
template void CalSpaceToDepth<int>(const size_t size, const int *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, int *output, cudaStream_t cuda_stream);
|
||||
template void CalSpaceToDepth<int64_t>(const size_t size, const int64_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, int64_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSpaceToDepth<int16_t>(const size_t size, const int16_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, int16_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSpaceToDepth<int8_t>(const size_t size, const int8_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, int8_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSpaceToDepth<uint8_t>(const size_t size, const uint8_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, uint8_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSpaceToDepth<uint16_t>(const size_t size, const uint16_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, uint16_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSpaceToDepth<uint32_t>(const size_t size, const uint32_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, uint32_t *output,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalSpaceToDepth<uint64_t>(const size_t size, const uint64_t *input, const size_t *input_shape,
|
||||
const size_t *output_shape, const size_t r, uint64_t *output,
|
||||
cudaStream_t cuda_stream);
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETODEPTH_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETODEPTH_H_
|
||||
|
||||
#define SPACETODEPTH_BUFFER_DIMENSION 4
|
||||
template <typename T>
|
||||
void CalSpaceToDepth(const size_t size, const T *input, const size_t *input_shape, const size_t *output_shape,
|
||||
const size_t r, T *output, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETODEPTH_H_
|
|
@ -0,0 +1,200 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations.array_ops as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
def DepthToSpaceNumpy(arr, block_size):
|
||||
'''
|
||||
DepthToSpace ops use numpy
|
||||
'''
|
||||
tmpshape = arr.shape
|
||||
newshape = []
|
||||
newshape.append(tmpshape[0])
|
||||
newshape.append(tmpshape[1]//block_size//block_size)
|
||||
newshape.append(tmpshape[2]*block_size)
|
||||
newshape.append(tmpshape[3]*block_size)
|
||||
output = arr.reshape(newshape[0], newshape[1], block_size, block_size, tmpshape[2], tmpshape[3])
|
||||
output = np.transpose(output, (0, 1, 4, 2, 5, 3))
|
||||
output = output.reshape(newshape)
|
||||
return output
|
||||
|
||||
class DepthToSpaceNet(nn.Cell):
|
||||
def __init__(self, nptype, block_size=2, input_shape=(1, 4, 3, 3)):
|
||||
super(DepthToSpaceNet, self).__init__()
|
||||
self.DepthToSpace = P.DepthToSpace(2)
|
||||
input_size = 1
|
||||
for i in input_shape:
|
||||
input_size = input_size*i
|
||||
self.data_np = np.arange(input_size).reshape(input_shape).astype(nptype)
|
||||
self.x = Parameter(initializer(Tensor(self.data_np), input_shape), name='x')
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.DepthToSpace(self.x)
|
||||
|
||||
|
||||
def DepthToSpace(nptype, block_size=2, input_shape=(1, 4, 3, 3)):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
input_size = 1
|
||||
for i in input_shape:
|
||||
input_size = input_size*i
|
||||
expect = np.arange(input_size).reshape(input_shape).astype(nptype)
|
||||
expect = DepthToSpaceNumpy(expect, block_size)
|
||||
|
||||
dts = DepthToSpaceNet(nptype, block_size, input_shape)
|
||||
output = dts()
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
def DepthToSpace_pynative(nptype, block_size=2, input_shape=(1, 4, 3, 3)):
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
input_size = 1
|
||||
for i in input_shape:
|
||||
input_size = input_size*i
|
||||
expect = np.arange(input_size).reshape(input_shape).astype(nptype)
|
||||
expect = DepthToSpaceNumpy(expect, block_size)
|
||||
|
||||
dts = P.DepthToSpace(2)
|
||||
arr_input = Tensor(np.arange(input_size).reshape(input_shape).astype(nptype))
|
||||
output = dts(arr_input)
|
||||
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_graph_float32():
|
||||
DepthToSpace(np.float32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_graph_float16():
|
||||
DepthToSpace(np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_graph_int32():
|
||||
DepthToSpace(np.int32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_graph_int64():
|
||||
DepthToSpace(np.int64)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_graph_int8():
|
||||
DepthToSpace(np.int8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_graph_int16():
|
||||
DepthToSpace(np.int16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_graph_uint8():
|
||||
DepthToSpace(np.uint8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_graph_uint16():
|
||||
DepthToSpace(np.uint16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_graph_uint32():
|
||||
DepthToSpace(np.uint32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_graph_uint64():
|
||||
DepthToSpace(np.uint64)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_pynative_float32():
|
||||
DepthToSpace_pynative(np.float32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_pynative_float16():
|
||||
DepthToSpace_pynative(np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_pynative_int32():
|
||||
DepthToSpace_pynative(np.int32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_pynative_int64():
|
||||
DepthToSpace_pynative(np.int64)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_pynative_int8():
|
||||
DepthToSpace_pynative(np.int8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_pynative_int16():
|
||||
DepthToSpace_pynative(np.int16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_pynative_uint8():
|
||||
DepthToSpace_pynative(np.uint8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_pynative_uint16():
|
||||
DepthToSpace_pynative(np.uint16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_pynative_uint32():
|
||||
DepthToSpace_pynative(np.uint32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_depthtospace_pynative_uint64():
|
||||
DepthToSpace_pynative(np.uint64)
|
|
@ -0,0 +1,210 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations.array_ops as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
def DepthToSpaceNumpy(arr, block_size):
|
||||
'''
|
||||
DepthToSpace ops use numpy
|
||||
DepthToSpace ops is reverse ops to SpaceToDepth ops
|
||||
therefore DepthToSpace's output can be SpaceToDepth's input
|
||||
'''
|
||||
tmpshape = arr.shape
|
||||
newshape = []
|
||||
newshape.append(tmpshape[0])
|
||||
newshape.append(tmpshape[1]//block_size//block_size)
|
||||
newshape.append(tmpshape[2]*block_size)
|
||||
newshape.append(tmpshape[3]*block_size)
|
||||
output = arr.reshape(newshape[0], newshape[1], block_size, block_size, tmpshape[2], tmpshape[3])
|
||||
output = np.transpose(output, (0, 1, 4, 2, 5, 3))
|
||||
output = output.reshape(newshape)
|
||||
return output
|
||||
|
||||
class SpaceToDepthNet(nn.Cell):
|
||||
def __init__(self, nptype, block_size=2, input_shape=(1, 4, 3, 3)):
|
||||
super(SpaceToDepthNet, self).__init__()
|
||||
self.SpaceToDepth = P.SpaceToDepth(block_size)
|
||||
input_size = 1
|
||||
for i in input_shape:
|
||||
input_size = input_size*i
|
||||
|
||||
data_np = np.arange(input_size).reshape(input_shape).astype(nptype)# data_np shape is (N,C,H,W)
|
||||
data_np = DepthToSpaceNumpy(data_np, block_size)#now data_np shape is (N,C/(block_size*block_size),H*block_size,W*block_size)
|
||||
self.data_np = data_np
|
||||
new_shape = []
|
||||
new_shape.append(input_shape[0])
|
||||
new_shape.append(input_shape[1]//(block_size*block_size))
|
||||
new_shape.append(input_shape[2]*block_size)
|
||||
new_shape.append(input_shape[3]*block_size)
|
||||
self.x = Parameter(initializer(Tensor(self.data_np), new_shape), name='x')
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return self.SpaceToDepth(self.x)
|
||||
|
||||
|
||||
def SpaceToDepth(nptype, block_size=2, input_shape=(1, 4, 3, 3)):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
input_size = 1
|
||||
for i in input_shape:
|
||||
input_size = input_size*i
|
||||
|
||||
expect = np.arange(input_size).reshape(input_shape).astype(nptype)
|
||||
|
||||
std = SpaceToDepthNet(nptype, block_size, input_shape)
|
||||
output = std()
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
def SpaceToDepth_pynative(nptype, block_size=2, input_shape=(1, 4, 3, 3)):
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
input_size = 1
|
||||
for i in input_shape:
|
||||
input_size = input_size*i
|
||||
expect = np.arange(input_size).reshape(input_shape).astype(nptype)
|
||||
arrinput = DepthToSpaceNumpy(expect, block_size)
|
||||
|
||||
std = P.SpaceToDepth(block_size)
|
||||
arrinput = Tensor(arrinput)
|
||||
output = std(arrinput)
|
||||
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_graph_float32():
|
||||
SpaceToDepth(np.float32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_graph_float16():
|
||||
SpaceToDepth(np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_graph_int32():
|
||||
SpaceToDepth(np.int32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_graph_int64():
|
||||
SpaceToDepth(np.int64)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_graph_int8():
|
||||
SpaceToDepth(np.int8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_graph_int16():
|
||||
SpaceToDepth(np.int16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_graph_uint8():
|
||||
SpaceToDepth(np.uint8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_graph_uint16():
|
||||
SpaceToDepth(np.uint16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_graph_uint32():
|
||||
SpaceToDepth(np.uint32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_graph_uint64():
|
||||
SpaceToDepth(np.uint64)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_pynative_float32():
|
||||
SpaceToDepth_pynative(np.float32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_pynative_float16():
|
||||
SpaceToDepth_pynative(np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_pynative_int32():
|
||||
SpaceToDepth_pynative(np.int32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_pynative_int64():
|
||||
SpaceToDepth_pynative(np.int64)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_pynative_int8():
|
||||
SpaceToDepth_pynative(np.int8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_pynative_int16():
|
||||
SpaceToDepth_pynative(np.int16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_pynative_uint8():
|
||||
SpaceToDepth_pynative(np.uint8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_pynative_uint16():
|
||||
SpaceToDepth_pynative(np.uint16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_pynative_uint32():
|
||||
SpaceToDepth_pynative(np.uint32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_spacetodepth_pynative_uint64():
|
||||
SpaceToDepth_pynative(np.uint64)
|
Loading…
Reference in New Issue