!3176 Add gpu support for ResizeNearestNeighbor

Merge pull request !3176 from 34bunny/GPU-ResizeNearestNeighbor
This commit is contained in:
mindspore-ci-bot 2020-07-22 13:38:11 +08:00 committed by Gitee
commit 38a52a5b67
11 changed files with 651 additions and 3 deletions

View File

@ -0,0 +1,31 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ResizeNearestNeighborGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ResizeNearestNeighborGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ResizeNearestNeighborGpuKernel, int)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,111 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ResizeNearestNeighborGpuKernel : public GpuKernel {
public:
ResizeNearestNeighborGpuKernel() : align_corners_(false), shape_size_(0), input_size_(0), output_size_(0) {}
~ResizeNearestNeighborGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
int size = SizeToInt(output_size_ / sizeof(T));
float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_);
float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_);
CalResizeNearestNeighbor(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output,
output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], align_corners_,
h_scale, w_scale, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but ResizeNearestNeighbor needs 1 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
shape_size_ = input_shape.size();
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (shape_size_ != RESIZENEARESTNEIGHBOR_DIMENSION) {
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only "
<< RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs.";
return false;
}
input_size_ = 1;
for (size_t i = 0; i < shape_size_; i++) {
input_size_ *= input_shape[i];
input_shape_.push_back(input_shape[i]);
}
input_size_ *= sizeof(T);
output_size_ = 1;
for (size_t i = 0; i < shape_size_; i++) {
output_size_ *= output_shape[i];
output_shape_.push_back(output_shape[i]);
}
output_size_ *= sizeof(T);
align_corners_ = GetAttr<bool>(kernel_node, "align_corners");
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
}
private:
float Scaling(const int in_size, const int out_size, bool align_corners) {
return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
: in_size / static_cast<float>(out_size);
}
bool align_corners_;
size_t shape_size_;
std::vector<int> input_shape_;
std::vector<int> output_shape_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_

View File

@ -0,0 +1,31 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ResizeNearestNeighborGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ResizeNearestNeighborGradGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ResizeNearestNeighborGradGpuKernel, int)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,111 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ResizeNearestNeighborGradGpuKernel : public GpuKernel {
public:
ResizeNearestNeighborGradGpuKernel() : align_corners_(false), shape_size_(0), input_size_(0), output_size_(0) {}
~ResizeNearestNeighborGradGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
int size = SizeToInt(output_size_ / sizeof(T));
float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_);
float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_);
CalResizeNearestNeighborGrad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3],
output, output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3],
align_corners_, h_scale, w_scale, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but ResizeNearestNeighbor needs 1 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
shape_size_ = input_shape.size();
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (shape_size_ != RESIZENEARESTNEIGHBORGRAD_DIMENSION) {
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only "
<< RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs.";
return false;
}
input_size_ = 1;
for (size_t i = 0; i < shape_size_; i++) {
input_size_ *= input_shape[i];
input_shape_.push_back(input_shape[i]);
}
input_size_ *= sizeof(T);
output_size_ = 1;
for (size_t i = 0; i < shape_size_; i++) {
output_size_ *= output_shape[i];
output_shape_.push_back(output_shape[i]);
}
output_size_ *= sizeof(T);
align_corners_ = GetAttr<bool>(kernel_node, "align_corners");
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
}
private:
float Scaling(const int in_size, const int out_size, bool align_corners) {
return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
: in_size / static_cast<float>(out_size);
}
bool align_corners_;
size_t shape_size_;
std::vector<int> input_shape_;
std::vector<int> output_shape_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_

View File

@ -0,0 +1,81 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <stdint.h>
#include <math.h>
#include <algorithm>
#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh"
template <typename T>
__global__ void ResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3,
const int s4, T *output, const int d1, const int d2, const int d3,
const int d4, bool align_corners, float h_scale, float w_scale) {
// initialization
// HalfPixelCenters false
int input_pos;
int pos_array[RESIZENEARESTNEIGHBORGRAD_DIMENSION];
int in_height = s3;
int in_width = s4;
// for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] +
// pos_array[1] * output_shape[2] * output_shape[3] +
// pos_array[2] * output_shape[3] +
// pos_array[3]
T h_scale_ = static_cast<T>(h_scale);
T w_scale_ = static_cast<T>(w_scale);
T out_h_;
T out_w_;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
pos_array[0] = pos / (d2 * d3 * d4) % d1;
pos_array[1] = pos / (d3 * d4) % d2;
pos_array[2] = pos / (d4) % d3;
pos_array[3] = pos % d4;
out_h_ = static_cast<T>(pos_array[2]);
out_w_ = static_cast<T>(pos_array[3]);
const int in_y =
min((align_corners) ? static_cast<int>(roundf(out_h_ * h_scale_)) : static_cast<int>(floorf(out_h_ * h_scale_)),
in_height - 1);
const int in_x =
min((align_corners) ? static_cast<int>(roundf(out_w_ * w_scale_)) : static_cast<int>(floorf(out_w_ * w_scale_)),
in_width - 1);
// pos_array[0] N, pos_array[1] C, in_y H, in_x W
input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x;
output[pos] = input[input_pos];
}
return;
}
template <typename T>
void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3,
const int s4, T *output, const int d1, const int d2, const int d3, const int d4,
bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream) {
ResizeNearestNeighborGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
size, input, s1, s2, s3, s4, output, d1, d2, d3, d4, align_corners, h_scale, w_scale);
return;
}
template void CalResizeNearestNeighborGrad<float>(const int size, const float *input, const int s1, const int s2,
const int s3, const int s4, float *output, const int d1, const int d2,
const int d3, const int d4, bool align_corners, float h_scale,
float w_scale, cudaStream_t cuda_stream);
template void CalResizeNearestNeighborGrad<half>(const int size, const half *input, const int s1, const int s2,
const int s3, const int s4, half *output, const int d1, const int d2,
const int d3, const int d4, bool align_corners, float h_scale,
float w_scale, cudaStream_t cuda_stream);
template void CalResizeNearestNeighborGrad<int>(const int size, const int *input, const int s1, const int s2,
const int s3, const int s4, int *output, const int d1, const int d2,
const int d3, const int d4, bool align_corners, float h_scale,
float w_scale, cudaStream_t cuda_stream);

View File

@ -0,0 +1,28 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_
#include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h"
#define RESIZENEARESTNEIGHBORGRAD_DIMENSION 4
template <typename T>
void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3,
const int s4, T *output, const int d1, const int d2, const int d3, const int d4,
bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_

View File

@ -0,0 +1,81 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <stdint.h>
#include <math.h>
#include <algorithm>
#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh"
template <typename T>
__global__ void ResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3,
const int s4, T *output, const int d1, const int d2, const int d3, const int d4,
bool align_corners, float h_scale, float w_scale) {
// initialization
// HalfPixelCenters false
int input_pos;
int pos_array[RESIZENEARESTNEIGHBOR_DIMENSION];
int in_height = s3;
int in_width = s4;
// for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] +
// pos_array[1] * output_shape[2] * output_shape[3] +
// pos_array[2] * output_shape[3] +
// pos_array[3]
T h_scale_ = static_cast<T>(h_scale);
T w_scale_ = static_cast<T>(w_scale);
T out_h_;
T out_w_;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
pos_array[0] = pos / (d2 * d3 * d4) % d1;
pos_array[1] = pos / (d3 * d4) % d2;
pos_array[2] = pos / (d4) % d3;
pos_array[3] = pos % d4;
out_h_ = static_cast<T>(pos_array[2]);
out_w_ = static_cast<T>(pos_array[3]);
const int in_y =
min((align_corners) ? static_cast<int>(roundf(out_h_ * h_scale_)) : static_cast<int>(floorf(out_h_ * h_scale_)),
in_height - 1);
const int in_x =
min((align_corners) ? static_cast<int>(roundf(out_w_ * w_scale_)) : static_cast<int>(floorf(out_w_ * w_scale_)),
in_width - 1);
// pos_array[0] N, pos_array[1] C, in_y H, in_x W
input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x;
output[pos] = input[input_pos];
}
return;
}
template <typename T>
void CalResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, const int s4,
T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners,
float h_scale, float w_scale, cudaStream_t cuda_stream) {
ResizeNearestNeighbor<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, s1, s2, s3, s4, output, d1, d2,
d3, d4, align_corners, h_scale, w_scale);
return;
}
template void CalResizeNearestNeighbor<float>(const int size, const float *input, const int s1, const int s2,
const int s3, const int s4, float *output, const int d1, const int d2,
const int d3, const int d4, bool align_corners, float h_scale,
float w_scale, cudaStream_t cuda_stream);
template void CalResizeNearestNeighbor<half>(const int size, const half *input, const int s1, const int s2,
const int s3, const int s4, half *output, const int d1, const int d2,
const int d3, const int d4, bool align_corners, float h_scale,
float w_scale, cudaStream_t cuda_stream);
template void CalResizeNearestNeighbor<int>(const int size, const int *input, const int s1, const int s2, const int s3,
const int s4, int *output, const int d1, const int d2, const int d3,
const int d4, bool align_corners, float h_scale, float w_scale,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,28 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_
#include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h"
#define RESIZENEARESTNEIGHBOR_DIMENSION 4
template <typename T>
void CalResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, const int s4,
T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners,
float h_scale, float w_scale, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_

View File

@ -2338,10 +2338,10 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
and output tensors are aligned. Default: False.
Inputs:
- **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
- **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
Outputs:
Tensor, the shape of the output tensor is :math:`(N, NEW\_C, NEW\_H, W)`.
Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`.
Examples:
>>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
@ -2360,7 +2360,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
def infer_shape(self, x):
validator.check('the dimension of input_x', len(x), '', 2, Rel.GE, self.name)
validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name)
return tuple(x)[:-2] + tuple(self.size)
def infer_dtype(self, x):

View File

@ -0,0 +1,75 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
class ResizeNearestNeighborGradAlignCornerT(nn.Cell):
def __init__(self):
super(ResizeNearestNeighborGradAlignCornerT, self).__init__()
self.ResizeNearestNeighborGradAlignCornerT = G.ResizeNearestNeighborGrad(align_corners=True)
def construct(self, dy, size):
return self.ResizeNearestNeighborGradAlignCornerT(dy, size)
class ResizeNearestNeighborGradAlignCornerF(nn.Cell):
def __init__(self):
super(ResizeNearestNeighborGradAlignCornerF, self).__init__()
self.ResizeNearestNeighborGradAlignCornerF = G.ResizeNearestNeighborGrad(align_corners=False)
def construct(self, dy, size):
return self.ResizeNearestNeighborGradAlignCornerF(dy, size)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ResizeNearestNeighborGradAlignCornerT():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
size = (2, 2)
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)
rnn = ResizeNearestNeighborGradAlignCornerT()
output = rnn(Tensor(dy), size)
assert np.all(output.asnumpy() == expect)
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
size = (2, 2)
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)
rnn = ResizeNearestNeighborGradAlignCornerT()
output = rnn(Tensor(dy), size)
assert np.all(output.asnumpy() == expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ResizeNearestNeighborGradAlignCornerF():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
size = (2, 2)
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)
rnn = ResizeNearestNeighborGradAlignCornerF()
output = rnn(Tensor(dy), size)
assert np.all(output.asnumpy() == expect)
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
size = (2, 2)
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)
rnn = ResizeNearestNeighborGradAlignCornerF()
output = rnn(Tensor(dy), size)
assert np.all(output.asnumpy() == expect)

View File

@ -0,0 +1,71 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class ResizeNearestNeighborAlignCornerT(nn.Cell):
def __init__(self, size):
super(ResizeNearestNeighborAlignCornerT, self).__init__()
self.ResizeNearestNeighborAlignCornerT = P.ResizeNearestNeighbor(size, align_corners=True)
def construct(self, x):
return self.ResizeNearestNeighborAlignCornerT(x)
class ResizeNearestNeighborAlignCornerF(nn.Cell):
def __init__(self, size):
super(ResizeNearestNeighborAlignCornerF, self).__init__()
self.ResizeNearestNeighborAlignCornerF = P.ResizeNearestNeighbor(size, align_corners=False)
def construct(self, x):
return self.ResizeNearestNeighborAlignCornerF(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ResizeNearestNeighborAlignCornerT():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32))
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
rnn = ResizeNearestNeighborAlignCornerT((4, 4))
output = rnn(input_tensor)
assert np.all(output.asnumpy() == expect)
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16))
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
rnn = ResizeNearestNeighborAlignCornerT((4, 4))
output = rnn(input_tensor)
assert np.all(output.asnumpy() == expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ResizeNearestNeighborAlignCornerF():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32))
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
rnn = ResizeNearestNeighborAlignCornerF((4, 4))
output = rnn(input_tensor)
assert np.all(output.asnumpy() == expect)
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16))
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
rnn = ResizeNearestNeighborAlignCornerF((4, 4))
output = rnn(input_tensor)
assert np.all(output.asnumpy() == expect)