forked from mindspore-Ecosystem/mindspore
!3176 Add gpu support for ResizeNearestNeighbor
Merge pull request !3176 from 34bunny/GPU-ResizeNearestNeighbor
This commit is contained in:
commit
38a52a5b67
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ResizeNearestNeighborGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
ResizeNearestNeighborGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ResizeNearestNeighborGpuKernel, int)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,111 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class ResizeNearestNeighborGpuKernel : public GpuKernel {
|
||||
public:
|
||||
ResizeNearestNeighborGpuKernel() : align_corners_(false), shape_size_(0), input_size_(0), output_size_(0) {}
|
||||
~ResizeNearestNeighborGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
int size = SizeToInt(output_size_ / sizeof(T));
|
||||
float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_);
|
||||
float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_);
|
||||
CalResizeNearestNeighbor(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output,
|
||||
output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], align_corners_,
|
||||
h_scale, w_scale, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but ResizeNearestNeighbor needs 1 input.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
shape_size_ = input_shape.size();
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
if (shape_size_ != RESIZENEARESTNEIGHBOR_DIMENSION) {
|
||||
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only "
|
||||
<< RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs.";
|
||||
return false;
|
||||
}
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < shape_size_; i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
input_shape_.push_back(input_shape[i]);
|
||||
}
|
||||
input_size_ *= sizeof(T);
|
||||
output_size_ = 1;
|
||||
for (size_t i = 0; i < shape_size_; i++) {
|
||||
output_size_ *= output_shape[i];
|
||||
output_shape_.push_back(output_shape[i]);
|
||||
}
|
||||
output_size_ *= sizeof(T);
|
||||
align_corners_ = GetAttr<bool>(kernel_node, "align_corners");
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
float Scaling(const int in_size, const int out_size, bool align_corners) {
|
||||
return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
|
||||
: in_size / static_cast<float>(out_size);
|
||||
}
|
||||
|
||||
bool align_corners_;
|
||||
size_t shape_size_;
|
||||
std::vector<int> input_shape_;
|
||||
std::vector<int> output_shape_;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ResizeNearestNeighborGradGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
ResizeNearestNeighborGradGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ResizeNearestNeighborGradGpuKernel, int)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,111 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class ResizeNearestNeighborGradGpuKernel : public GpuKernel {
|
||||
public:
|
||||
ResizeNearestNeighborGradGpuKernel() : align_corners_(false), shape_size_(0), input_size_(0), output_size_(0) {}
|
||||
~ResizeNearestNeighborGradGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
int size = SizeToInt(output_size_ / sizeof(T));
|
||||
float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_);
|
||||
float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_);
|
||||
CalResizeNearestNeighborGrad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3],
|
||||
output, output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3],
|
||||
align_corners_, h_scale, w_scale, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but ResizeNearestNeighbor needs 1 input.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
shape_size_ = input_shape.size();
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
if (shape_size_ != RESIZENEARESTNEIGHBORGRAD_DIMENSION) {
|
||||
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only "
|
||||
<< RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs.";
|
||||
return false;
|
||||
}
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < shape_size_; i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
input_shape_.push_back(input_shape[i]);
|
||||
}
|
||||
input_size_ *= sizeof(T);
|
||||
output_size_ = 1;
|
||||
for (size_t i = 0; i < shape_size_; i++) {
|
||||
output_size_ *= output_shape[i];
|
||||
output_shape_.push_back(output_shape[i]);
|
||||
}
|
||||
output_size_ *= sizeof(T);
|
||||
align_corners_ = GetAttr<bool>(kernel_node, "align_corners");
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
float Scaling(const int in_size, const int out_size, bool align_corners) {
|
||||
return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
|
||||
: in_size / static_cast<float>(out_size);
|
||||
}
|
||||
|
||||
bool align_corners_;
|
||||
size_t shape_size_;
|
||||
std::vector<int> input_shape_;
|
||||
std::vector<int> output_shape_;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_
|
|
@ -0,0 +1,81 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <algorithm>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh"
|
||||
|
||||
template <typename T>
|
||||
__global__ void ResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3,
|
||||
const int s4, T *output, const int d1, const int d2, const int d3,
|
||||
const int d4, bool align_corners, float h_scale, float w_scale) {
|
||||
// initialization
|
||||
// HalfPixelCenters false
|
||||
int input_pos;
|
||||
int pos_array[RESIZENEARESTNEIGHBORGRAD_DIMENSION];
|
||||
int in_height = s3;
|
||||
int in_width = s4;
|
||||
// for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] +
|
||||
// pos_array[1] * output_shape[2] * output_shape[3] +
|
||||
// pos_array[2] * output_shape[3] +
|
||||
// pos_array[3]
|
||||
T h_scale_ = static_cast<T>(h_scale);
|
||||
T w_scale_ = static_cast<T>(w_scale);
|
||||
T out_h_;
|
||||
T out_w_;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
pos_array[0] = pos / (d2 * d3 * d4) % d1;
|
||||
pos_array[1] = pos / (d3 * d4) % d2;
|
||||
pos_array[2] = pos / (d4) % d3;
|
||||
pos_array[3] = pos % d4;
|
||||
out_h_ = static_cast<T>(pos_array[2]);
|
||||
out_w_ = static_cast<T>(pos_array[3]);
|
||||
const int in_y =
|
||||
min((align_corners) ? static_cast<int>(roundf(out_h_ * h_scale_)) : static_cast<int>(floorf(out_h_ * h_scale_)),
|
||||
in_height - 1);
|
||||
const int in_x =
|
||||
min((align_corners) ? static_cast<int>(roundf(out_w_ * w_scale_)) : static_cast<int>(floorf(out_w_ * w_scale_)),
|
||||
in_width - 1);
|
||||
// pos_array[0] N, pos_array[1] C, in_y H, in_x W
|
||||
input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x;
|
||||
output[pos] = input[input_pos];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3,
|
||||
const int s4, T *output, const int d1, const int d2, const int d3, const int d4,
|
||||
bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream) {
|
||||
ResizeNearestNeighborGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
|
||||
size, input, s1, s2, s3, s4, output, d1, d2, d3, d4, align_corners, h_scale, w_scale);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalResizeNearestNeighborGrad<float>(const int size, const float *input, const int s1, const int s2,
|
||||
const int s3, const int s4, float *output, const int d1, const int d2,
|
||||
const int d3, const int d4, bool align_corners, float h_scale,
|
||||
float w_scale, cudaStream_t cuda_stream);
|
||||
template void CalResizeNearestNeighborGrad<half>(const int size, const half *input, const int s1, const int s2,
|
||||
const int s3, const int s4, half *output, const int d1, const int d2,
|
||||
const int d3, const int d4, bool align_corners, float h_scale,
|
||||
float w_scale, cudaStream_t cuda_stream);
|
||||
template void CalResizeNearestNeighborGrad<int>(const int size, const int *input, const int s1, const int s2,
|
||||
const int s3, const int s4, int *output, const int d1, const int d2,
|
||||
const int d3, const int d4, bool align_corners, float h_scale,
|
||||
float w_scale, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_
|
||||
#include <cuda_runtime.h>
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
#define RESIZENEARESTNEIGHBORGRAD_DIMENSION 4
|
||||
|
||||
template <typename T>
|
||||
void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3,
|
||||
const int s4, T *output, const int d1, const int d2, const int d3, const int d4,
|
||||
bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_
|
|
@ -0,0 +1,81 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <algorithm>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh"
|
||||
|
||||
template <typename T>
|
||||
__global__ void ResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3,
|
||||
const int s4, T *output, const int d1, const int d2, const int d3, const int d4,
|
||||
bool align_corners, float h_scale, float w_scale) {
|
||||
// initialization
|
||||
// HalfPixelCenters false
|
||||
int input_pos;
|
||||
int pos_array[RESIZENEARESTNEIGHBOR_DIMENSION];
|
||||
int in_height = s3;
|
||||
int in_width = s4;
|
||||
// for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] +
|
||||
// pos_array[1] * output_shape[2] * output_shape[3] +
|
||||
// pos_array[2] * output_shape[3] +
|
||||
// pos_array[3]
|
||||
T h_scale_ = static_cast<T>(h_scale);
|
||||
T w_scale_ = static_cast<T>(w_scale);
|
||||
T out_h_;
|
||||
T out_w_;
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
pos_array[0] = pos / (d2 * d3 * d4) % d1;
|
||||
pos_array[1] = pos / (d3 * d4) % d2;
|
||||
pos_array[2] = pos / (d4) % d3;
|
||||
pos_array[3] = pos % d4;
|
||||
out_h_ = static_cast<T>(pos_array[2]);
|
||||
out_w_ = static_cast<T>(pos_array[3]);
|
||||
const int in_y =
|
||||
min((align_corners) ? static_cast<int>(roundf(out_h_ * h_scale_)) : static_cast<int>(floorf(out_h_ * h_scale_)),
|
||||
in_height - 1);
|
||||
const int in_x =
|
||||
min((align_corners) ? static_cast<int>(roundf(out_w_ * w_scale_)) : static_cast<int>(floorf(out_w_ * w_scale_)),
|
||||
in_width - 1);
|
||||
// pos_array[0] N, pos_array[1] C, in_y H, in_x W
|
||||
input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x;
|
||||
output[pos] = input[input_pos];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, const int s4,
|
||||
T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners,
|
||||
float h_scale, float w_scale, cudaStream_t cuda_stream) {
|
||||
ResizeNearestNeighbor<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, s1, s2, s3, s4, output, d1, d2,
|
||||
d3, d4, align_corners, h_scale, w_scale);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalResizeNearestNeighbor<float>(const int size, const float *input, const int s1, const int s2,
|
||||
const int s3, const int s4, float *output, const int d1, const int d2,
|
||||
const int d3, const int d4, bool align_corners, float h_scale,
|
||||
float w_scale, cudaStream_t cuda_stream);
|
||||
template void CalResizeNearestNeighbor<half>(const int size, const half *input, const int s1, const int s2,
|
||||
const int s3, const int s4, half *output, const int d1, const int d2,
|
||||
const int d3, const int d4, bool align_corners, float h_scale,
|
||||
float w_scale, cudaStream_t cuda_stream);
|
||||
template void CalResizeNearestNeighbor<int>(const int size, const int *input, const int s1, const int s2, const int s3,
|
||||
const int s4, int *output, const int d1, const int d2, const int d3,
|
||||
const int d4, bool align_corners, float h_scale, float w_scale,
|
||||
cudaStream_t cuda_stream);
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_
|
||||
#include <cuda_runtime.h>
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
#define RESIZENEARESTNEIGHBOR_DIMENSION 4
|
||||
|
||||
template <typename T>
|
||||
void CalResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, const int s4,
|
||||
T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners,
|
||||
float h_scale, float w_scale, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_
|
|
@ -2338,10 +2338,10 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
|
|||
and output tensors are aligned. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
|
||||
- **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of the output tensor is :math:`(N, NEW\_C, NEW\_H, W)`.
|
||||
Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
|
||||
|
@ -2360,7 +2360,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
|
||||
|
||||
def infer_shape(self, x):
|
||||
validator.check('the dimension of input_x', len(x), '', 2, Rel.GE, self.name)
|
||||
validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name)
|
||||
return tuple(x)[:-2] + tuple(self.size)
|
||||
|
||||
def infer_dtype(self, x):
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
|
||||
class ResizeNearestNeighborGradAlignCornerT(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ResizeNearestNeighborGradAlignCornerT, self).__init__()
|
||||
self.ResizeNearestNeighborGradAlignCornerT = G.ResizeNearestNeighborGrad(align_corners=True)
|
||||
|
||||
def construct(self, dy, size):
|
||||
return self.ResizeNearestNeighborGradAlignCornerT(dy, size)
|
||||
|
||||
class ResizeNearestNeighborGradAlignCornerF(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ResizeNearestNeighborGradAlignCornerF, self).__init__()
|
||||
self.ResizeNearestNeighborGradAlignCornerF = G.ResizeNearestNeighborGrad(align_corners=False)
|
||||
|
||||
def construct(self, dy, size):
|
||||
return self.ResizeNearestNeighborGradAlignCornerF(dy, size)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ResizeNearestNeighborGradAlignCornerT():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
|
||||
size = (2, 2)
|
||||
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)
|
||||
rnn = ResizeNearestNeighborGradAlignCornerT()
|
||||
output = rnn(Tensor(dy), size)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
|
||||
size = (2, 2)
|
||||
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)
|
||||
rnn = ResizeNearestNeighborGradAlignCornerT()
|
||||
output = rnn(Tensor(dy), size)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ResizeNearestNeighborGradAlignCornerF():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
|
||||
size = (2, 2)
|
||||
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)
|
||||
rnn = ResizeNearestNeighborGradAlignCornerF()
|
||||
output = rnn(Tensor(dy), size)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
|
||||
size = (2, 2)
|
||||
expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)
|
||||
rnn = ResizeNearestNeighborGradAlignCornerF()
|
||||
output = rnn(Tensor(dy), size)
|
||||
assert np.all(output.asnumpy() == expect)
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class ResizeNearestNeighborAlignCornerT(nn.Cell):
|
||||
def __init__(self, size):
|
||||
super(ResizeNearestNeighborAlignCornerT, self).__init__()
|
||||
self.ResizeNearestNeighborAlignCornerT = P.ResizeNearestNeighbor(size, align_corners=True)
|
||||
|
||||
def construct(self, x):
|
||||
return self.ResizeNearestNeighborAlignCornerT(x)
|
||||
|
||||
class ResizeNearestNeighborAlignCornerF(nn.Cell):
|
||||
def __init__(self, size):
|
||||
super(ResizeNearestNeighborAlignCornerF, self).__init__()
|
||||
self.ResizeNearestNeighborAlignCornerF = P.ResizeNearestNeighbor(size, align_corners=False)
|
||||
|
||||
def construct(self, x):
|
||||
return self.ResizeNearestNeighborAlignCornerF(x)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ResizeNearestNeighborAlignCornerT():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32))
|
||||
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
|
||||
rnn = ResizeNearestNeighborAlignCornerT((4, 4))
|
||||
output = rnn(input_tensor)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16))
|
||||
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
|
||||
rnn = ResizeNearestNeighborAlignCornerT((4, 4))
|
||||
output = rnn(input_tensor)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_ResizeNearestNeighborAlignCornerF():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32))
|
||||
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
|
||||
rnn = ResizeNearestNeighborAlignCornerF((4, 4))
|
||||
output = rnn(input_tensor)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16))
|
||||
expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
|
||||
rnn = ResizeNearestNeighborAlignCornerF((4, 4))
|
||||
output = rnn(input_tensor)
|
||||
assert np.all(output.asnumpy() == expect)
|
Loading…
Reference in New Issue