add maxunpool2d,maxunpool2dGrad GPU
This commit is contained in:
parent
33f12ac66b
commit
bbe1c7e796
|
@ -0,0 +1,270 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_MAXUNPOOL2D_HELPER_H_
|
||||||
|
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_MAXUNPOOL2D_HELPER_H_
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxunpool2d_impl.cuh"
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxunpool2d_grad_impl.cuh"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace cukernel {
|
||||||
|
class MaxUnpool2DAttr : public GpuKernelAttrBase {
|
||||||
|
public:
|
||||||
|
MaxUnpool2DAttr() = default;
|
||||||
|
~MaxUnpool2DAttr() override = default;
|
||||||
|
std::vector<int64_t> ksize;
|
||||||
|
std::vector<int64_t> strides;
|
||||||
|
std::vector<int64_t> pads;
|
||||||
|
std::vector<int64_t> output_shape;
|
||||||
|
std::string data_format;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
class MaxUnpool2DHelperGpuKernel : public GpuKernelHelperBase {
|
||||||
|
public:
|
||||||
|
explicit MaxUnpool2DHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
|
||||||
|
: GpuKernelHelperBase(kernel_name, device_id) {
|
||||||
|
is_null_input_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~MaxUnpool2DHelperGpuKernel() = default;
|
||||||
|
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
|
||||||
|
const std::vector<std::vector<int64_t>> &output_shapes) override {
|
||||||
|
constexpr size_t OUTPUT_NUM = 1;
|
||||||
|
ResetResource();
|
||||||
|
|
||||||
|
input_shape_ = input_shapes[kIndex0];
|
||||||
|
indices_shape_ = input_shapes[kIndex1];
|
||||||
|
|
||||||
|
size_t cur_size_T = sizeof(T);
|
||||||
|
for (const auto &val : input_shape_) {
|
||||||
|
cur_size_T *= val;
|
||||||
|
}
|
||||||
|
input_size_list_.emplace_back(cur_size_T);
|
||||||
|
|
||||||
|
size_t cur_size_S = sizeof(S);
|
||||||
|
for (const auto &val : indices_shape_) {
|
||||||
|
cur_size_S *= val;
|
||||||
|
}
|
||||||
|
input_size_list_.emplace_back(cur_size_S);
|
||||||
|
work_size_list_.emplace_back(sizeof(int64_t));
|
||||||
|
int out_flag =
|
||||||
|
CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
|
||||||
|
if (out_flag == -1) {
|
||||||
|
return out_flag;
|
||||||
|
}
|
||||||
|
output_shape_ = output_shapes[kIndex0];
|
||||||
|
is_null_input_ = (out_flag == 1);
|
||||||
|
return CheckKernelParam();
|
||||||
|
}
|
||||||
|
|
||||||
|
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
|
||||||
|
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
|
||||||
|
if (is_null_input_) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
T *input_ptr = nullptr;
|
||||||
|
S *indices = nullptr;
|
||||||
|
T *output_ptr = nullptr;
|
||||||
|
int64_t *gpuflag = nullptr;
|
||||||
|
int flag = GetDeviceAddress<T>(input_ptrs, kIndex0, kernel_name_, &input_ptr);
|
||||||
|
if (flag != 0) {
|
||||||
|
return flag;
|
||||||
|
}
|
||||||
|
flag = GetDeviceAddress<S>(input_ptrs, kIndex1, kernel_name_, &indices);
|
||||||
|
if (flag != 0) {
|
||||||
|
return flag;
|
||||||
|
}
|
||||||
|
flag = GetDeviceAddress<T>(output_ptrs, kIndex0, kernel_name_, &output_ptr);
|
||||||
|
if (flag != 0) {
|
||||||
|
return flag;
|
||||||
|
}
|
||||||
|
flag = GetDeviceAddress<int64_t>(work_ptrs, kIndex0, kernel_name_, &gpuflag);
|
||||||
|
if (flag != 0) {
|
||||||
|
return flag;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t dims = static_cast<int64_t>(input_shape_.size());
|
||||||
|
int64_t outer_size = 1;
|
||||||
|
for (int64_t i = dims - 1; i >= 0; i--) {
|
||||||
|
outer_size *= output_shape_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t thread_size = 1;
|
||||||
|
for (int64_t i = dims - 1; i >= 0; i--) {
|
||||||
|
thread_size *= input_shape_[i];
|
||||||
|
}
|
||||||
|
CalMaxUnpool2D(input_ptr, indices, input_shape_, output_shape_, output_ptr, outer_size, thread_size, data_format_,
|
||||||
|
device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
|
||||||
|
attr_ptr_ = std::dynamic_pointer_cast<MaxUnpool2DAttr>(kernel_attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int CheckKernelParam() override {
|
||||||
|
data_format_ = attr_ptr_->data_format;
|
||||||
|
if (data_format_ != "NCHW" && data_format_ != "NHWC") {
|
||||||
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'data_format' must be 'NCHW' or 'NHWC' ,"
|
||||||
|
<< " but got " << data_format_;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
data_format_ = attr_ptr_->data_format;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<MaxUnpool2DAttr> attr_ptr_;
|
||||||
|
std::vector<int64_t> input_shape_;
|
||||||
|
std::vector<int64_t> indices_shape_;
|
||||||
|
std::vector<int64_t> output_shape_;
|
||||||
|
std::string data_format_;
|
||||||
|
bool is_null_input_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MaxUnpool2DGradAttr : public GpuKernelAttrBase {
|
||||||
|
public:
|
||||||
|
MaxUnpool2DGradAttr() = default;
|
||||||
|
~MaxUnpool2DGradAttr() override = default;
|
||||||
|
std::vector<int64_t> ksize;
|
||||||
|
std::vector<int64_t> strides;
|
||||||
|
std::vector<int64_t> pads;
|
||||||
|
std::vector<int64_t> output_shape;
|
||||||
|
std::string data_format;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
class MaxUnpool2DGradHelperGpuKernel : public GpuKernelHelperBase {
|
||||||
|
public:
|
||||||
|
explicit MaxUnpool2DGradHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
|
||||||
|
: GpuKernelHelperBase(kernel_name, device_id) {
|
||||||
|
is_null_input_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~MaxUnpool2DGradHelperGpuKernel() = default;
|
||||||
|
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
|
||||||
|
const std::vector<std::vector<int64_t>> &output_shapes) override {
|
||||||
|
constexpr size_t OUTPUT_NUM = 1;
|
||||||
|
ResetResource();
|
||||||
|
|
||||||
|
backprop_input_shape_ = input_shapes[kIndex0];
|
||||||
|
grad_shape_ = input_shapes[kIndex1];
|
||||||
|
indices_shape_ = input_shapes[kIndex2];
|
||||||
|
|
||||||
|
size_t cur_size_T = sizeof(T);
|
||||||
|
for (const auto &val : backprop_input_shape_) {
|
||||||
|
cur_size_T *= val;
|
||||||
|
}
|
||||||
|
input_size_list_.emplace_back(cur_size_T);
|
||||||
|
|
||||||
|
cur_size_T = sizeof(T);
|
||||||
|
for (const auto &val : grad_shape_) {
|
||||||
|
cur_size_T *= val;
|
||||||
|
}
|
||||||
|
input_size_list_.emplace_back(cur_size_T);
|
||||||
|
|
||||||
|
size_t cur_size_S = sizeof(S);
|
||||||
|
for (const auto &val : indices_shape_) {
|
||||||
|
cur_size_S *= val;
|
||||||
|
}
|
||||||
|
input_size_list_.emplace_back(cur_size_S);
|
||||||
|
work_size_list_.emplace_back(sizeof(int64_t));
|
||||||
|
int out_flag =
|
||||||
|
CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
|
||||||
|
if (out_flag == -1) {
|
||||||
|
return out_flag;
|
||||||
|
}
|
||||||
|
backprop_output_shape_ = output_shapes[kIndex0];
|
||||||
|
is_null_input_ = (out_flag == 1);
|
||||||
|
return CheckKernelParam();
|
||||||
|
}
|
||||||
|
|
||||||
|
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
|
||||||
|
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
|
||||||
|
if (is_null_input_) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
T *input_ptr = nullptr;
|
||||||
|
T *grad = nullptr;
|
||||||
|
S *indices = nullptr;
|
||||||
|
T *output_ptr = nullptr;
|
||||||
|
int64_t *gpuflag = nullptr;
|
||||||
|
int flag = GetDeviceAddress<T>(input_ptrs, kIndex0, kernel_name_, &input_ptr);
|
||||||
|
if (flag != 0) {
|
||||||
|
return flag;
|
||||||
|
}
|
||||||
|
flag = GetDeviceAddress<T>(input_ptrs, kIndex1, kernel_name_, &grad);
|
||||||
|
if (flag != 0) {
|
||||||
|
return flag;
|
||||||
|
}
|
||||||
|
flag = GetDeviceAddress<S>(input_ptrs, kIndex2, kernel_name_, &indices);
|
||||||
|
if (flag != 0) {
|
||||||
|
return flag;
|
||||||
|
}
|
||||||
|
flag = GetDeviceAddress<T>(output_ptrs, kIndex0, kernel_name_, &output_ptr);
|
||||||
|
if (flag != 0) {
|
||||||
|
return flag;
|
||||||
|
}
|
||||||
|
flag = GetDeviceAddress<int64_t>(work_ptrs, kIndex0, kernel_name_, &gpuflag);
|
||||||
|
if (flag != 0) {
|
||||||
|
return flag;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t dims = static_cast<int64_t>(backprop_input_shape_.size());
|
||||||
|
int64_t outer_size = 1;
|
||||||
|
for (int64_t i = dims - 1; i >= 0; i--) {
|
||||||
|
outer_size *= backprop_output_shape_[i];
|
||||||
|
}
|
||||||
|
CalMaxUnpool2DGrad(grad, indices, backprop_input_shape_, grad_shape_, output_ptr, outer_size, data_format_,
|
||||||
|
device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
|
||||||
|
attr_ptr_ = std::dynamic_pointer_cast<MaxUnpool2DGradAttr>(kernel_attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int CheckKernelParam() override {
|
||||||
|
data_format_ = attr_ptr_->data_format;
|
||||||
|
if (data_format_ != "NCHW" && data_format_ != "NHWC") {
|
||||||
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'data_format' must be 'NCHW' or 'NHWC' ,"
|
||||||
|
<< " but got " << data_format_;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
data_format_ = attr_ptr_->data_format;
|
||||||
|
return 0;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<MaxUnpool2DGradAttr> attr_ptr_;
|
||||||
|
std::vector<int64_t> backprop_input_shape_;
|
||||||
|
std::vector<int64_t> grad_shape_;
|
||||||
|
std::vector<int64_t> indices_shape_;
|
||||||
|
std::vector<int64_t> backprop_output_shape_;
|
||||||
|
std::string data_format_;
|
||||||
|
bool is_null_input_;
|
||||||
|
};
|
||||||
|
} // namespace cukernel
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ARGMAX_HELPER_H_
|
|
@ -0,0 +1,205 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxunpool2d_grad_impl.cuh"
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
__global__ void MaxUnpool2DGradNCHW(const T *grad, const S *indices, const int64_t inputChannel,
|
||||||
|
const int64_t inputHeight, const int64_t inputWidth, const int64_t outputChannel,
|
||||||
|
const int64_t outputHeight, const int64_t outputWidth, const int64_t outer_size,
|
||||||
|
T *output) {
|
||||||
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outer_size; pos += blockDim.x * gridDim.x) {
|
||||||
|
const int posn = pos / (inputHeight * inputWidth * inputChannel);
|
||||||
|
const int posc = pos / (inputWidth * inputHeight) % inputChannel;
|
||||||
|
S maxind = indices[pos];
|
||||||
|
output[pos] = grad[maxind + (posn * inputChannel + posc) * outputHeight * outputWidth];
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
__global__ void MaxUnpool2DGradNHWC(const T *grad, const S *indices, const int64_t inputHeight,
|
||||||
|
const int64_t inputWidth, const int64_t inputChannel, const int64_t outputHeight,
|
||||||
|
const int64_t outputWidth, const int64_t outputChannel, const int64_t outer_size,
|
||||||
|
T *output) {
|
||||||
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outer_size; pos += blockDim.x * gridDim.x) {
|
||||||
|
const int posn = pos / (inputHeight * inputWidth * inputChannel);
|
||||||
|
const int posc = pos % inputChannel;
|
||||||
|
S maxind = indices[pos];
|
||||||
|
output[pos] = grad[(posn * outputHeight * outputWidth + maxind) * outputChannel + posc];
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
void CalMaxUnpool2DGrad(const T *grad, const S *indices, const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, T *output, const int64_t outer_size,
|
||||||
|
const std::string data_format_, const uint32_t &device_id,
|
||||||
|
cudaStream_t cuda_stream) {
|
||||||
|
if (data_format_ == "NCHW") {
|
||||||
|
MaxUnpool2DGradNCHW<<<CUDA_BLOCKS(device_id, outer_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
|
||||||
|
grad, indices, backprop_input_shape[1], backprop_input_shape[2], backprop_input_shape[3], grad_shape[1],
|
||||||
|
grad_shape[2], grad_shape[3], outer_size, output);
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
MaxUnpool2DGradNHWC<<<CUDA_BLOCKS(device_id, outer_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
|
||||||
|
grad, indices, backprop_input_shape[1], backprop_input_shape[2], backprop_input_shape[3], grad_shape[1],
|
||||||
|
grad_shape[2], grad_shape[3], outer_size, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<uint8_t, int32_t>(const uint8_t *grad, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape,
|
||||||
|
uint8_t *output, const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<uint8_t, int64_t>(const uint8_t *grad, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape,
|
||||||
|
uint8_t *output, const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<uint16_t, int32_t>(
|
||||||
|
const uint16_t *grad, const int32_t *indices, const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, uint16_t *output, const int64_t outer_size, const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<uint16_t, int64_t>(
|
||||||
|
const uint16_t *grad, const int64_t *indices, const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, uint16_t *output, const int64_t outer_size, const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<uint32_t, int32_t>(
|
||||||
|
const uint32_t *grad, const int32_t *indices, const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, uint32_t *output, const int64_t outer_size, const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<uint32_t, int64_t>(
|
||||||
|
const uint32_t *grad, const int64_t *indices, const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, uint32_t *output, const int64_t outer_size, const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<uint64_t, int32_t>(
|
||||||
|
const uint64_t *grad, const int32_t *indices, const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, uint64_t *output, const int64_t outer_size, const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<uint64_t, int64_t>(
|
||||||
|
const uint64_t *grad, const int64_t *indices, const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, uint64_t *output, const int64_t outer_size, const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<int8_t, int32_t>(const int8_t *grad, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, int8_t *output,
|
||||||
|
const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<int8_t, int64_t>(const int8_t *grad, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, int8_t *output,
|
||||||
|
const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<int16_t, int32_t>(const int16_t *grad, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape,
|
||||||
|
int16_t *output, const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<int16_t, int64_t>(const int16_t *grad, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape,
|
||||||
|
int16_t *output, const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<int32_t, int32_t>(const int32_t *grad, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape,
|
||||||
|
int32_t *output, const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<int32_t, int64_t>(const int32_t *grad, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape,
|
||||||
|
int32_t *output, const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<int64_t, int32_t>(const int64_t *grad, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape,
|
||||||
|
int64_t *output, const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<int64_t, int64_t>(const int64_t *grad, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape,
|
||||||
|
int64_t *output, const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<half, int32_t>(const half *grad, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, half *output,
|
||||||
|
const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<half, int64_t>(const half *grad, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, half *output,
|
||||||
|
const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<float, int32_t>(const float *grad, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, float *output,
|
||||||
|
const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<float, int64_t>(const float *grad, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, float *output,
|
||||||
|
const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<double, int32_t>(const double *grad, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, double *output,
|
||||||
|
const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2DGrad<double, int64_t>(const double *grad, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, double *output,
|
||||||
|
const int64_t outer_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,31 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXUNPOOL2DGRAD_IMPL_CUH_
|
||||||
|
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXUNPOOL2DGRAD_IMPL_CUH_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||||
|
#include "include/cuda_fp16.h"
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
CUDA_LIB_EXPORT void CalMaxUnpool2DGrad(const T *grad, const S *indices,
|
||||||
|
const std::vector<int64_t> backprop_input_shape,
|
||||||
|
const std::vector<int64_t> grad_shape, T *output, const int64_t outer_size,
|
||||||
|
const std::string data_format_, const uint32_t &device_id,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ARGMAX_IMPL_CUH_
|
|
@ -0,0 +1,224 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/maxunpool2d_impl.cuh"
|
||||||
|
template <typename T>
|
||||||
|
__global__ void InitMaxUnpool2D(const int64_t outer_size, T *output) {
|
||||||
|
T zero = 0;
|
||||||
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outer_size; pos += blockDim.x * gridDim.x) {
|
||||||
|
output[pos] = zero;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
__global__ void MaxUnpool2DNCHW(const T *input, const S *indices, const int64_t inputBatch, const int64_t inputChannel,
|
||||||
|
const int64_t inputHeight, const int64_t inputWidth, const int64_t outputChannel,
|
||||||
|
const int64_t outputHeight, const int64_t outputWidth, const int64_t thread_size,
|
||||||
|
T *output) {
|
||||||
|
int posn = blockIdx.z;
|
||||||
|
int posc = blockIdx.y;
|
||||||
|
output += (posn * inputChannel + posc) * outputHeight * outputWidth;
|
||||||
|
input += (posn * inputChannel + posc) * inputHeight * inputWidth;
|
||||||
|
indices += (posn * inputChannel + posc) * inputHeight * inputWidth;
|
||||||
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_size; pos += blockDim.x * gridDim.x) {
|
||||||
|
S maxind = indices[pos];
|
||||||
|
CUDA_KERNEL_ASSERT(maxind >= 0 && maxind < outputHeight * outputWidth);
|
||||||
|
output[maxind] = input[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
__global__ void MaxUnpool2DNHWC(const T *input, const S *indices, const int64_t inputBatch, const int64_t inputHeight,
|
||||||
|
const int64_t inputWidth, const int64_t inputChannel, const int64_t outputHeight,
|
||||||
|
const int64_t outputWidth, const int64_t outputChannel, const int64_t thread_size,
|
||||||
|
T *output) {
|
||||||
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_size; pos += blockDim.x * gridDim.x) {
|
||||||
|
const int posn = pos / (inputHeight * inputWidth * inputChannel);
|
||||||
|
const int posc = pos % inputChannel;
|
||||||
|
S maxind = indices[pos];
|
||||||
|
CUDA_KERNEL_ASSERT(maxind >= 0 && maxind < inputChannel * outputHeight * outputWidth);
|
||||||
|
output[(posn * outputHeight * outputWidth + maxind) * outputChannel + posc] = input[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
void CalMaxUnpool2D(const T *input, const S *indices, const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, T *output, const int64_t outer_size,
|
||||||
|
const int64_t thread_size, const std::string data_format_, const uint32_t &device_id,
|
||||||
|
cudaStream_t cuda_stream) {
|
||||||
|
InitMaxUnpool2D<<<CUDA_BLOCKS(device_id, outer_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(outer_size, output);
|
||||||
|
if (data_format_ == "NCHW") {
|
||||||
|
int outputPlaneSize = input_shape[2] * input_shape[3];
|
||||||
|
dim3 grid((outputPlaneSize + 127) / 128, input_shape[1], input_shape[0]);
|
||||||
|
dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
|
||||||
|
MaxUnpool2DNCHW<<<grid, block, 0, cuda_stream>>>(
|
||||||
|
input, indices, input_shape[0], input_shape[1], input_shape[2], input_shape[3], output_shape[1], output_shape[2],
|
||||||
|
output_shape[3], outputPlaneSize, output);
|
||||||
|
} else {
|
||||||
|
MaxUnpool2DNHWC<<<CUDA_BLOCKS(device_id, thread_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
|
||||||
|
input, indices, input_shape[0], input_shape[1], input_shape[2], input_shape[3], output_shape[1], output_shape[2],
|
||||||
|
output_shape[3], thread_size, output);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<uint8_t, int32_t>(const uint8_t *input, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, uint8_t *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<uint8_t, int64_t>(const uint8_t *input, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, uint8_t *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<uint16_t, int32_t>(
|
||||||
|
const uint16_t *input, const int32_t *indices, const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, uint16_t *output, const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_, const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<uint16_t, int64_t>(
|
||||||
|
const uint16_t *input, const int64_t *indices, const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, uint16_t *output, const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_, const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<uint32_t, int32_t>(
|
||||||
|
const uint32_t *input, const int32_t *indices, const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, uint32_t *output, const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_, const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<uint32_t, int64_t>(
|
||||||
|
const uint32_t *input, const int64_t *indices, const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, uint32_t *output, const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_, const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<uint64_t, int32_t>(
|
||||||
|
const uint64_t *input, const int32_t *indices, const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, uint64_t *output, const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_, const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<uint64_t, int64_t>(
|
||||||
|
const uint64_t *input, const int64_t *indices, const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, uint64_t *output, const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_, const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<int8_t, int32_t>(const int8_t *input, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, int8_t *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<int8_t, int64_t>(const int8_t *input, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, int8_t *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<int16_t, int32_t>(const int16_t *input, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, int16_t *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<int16_t, int64_t>(const int16_t *input, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, int16_t *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<int32_t, int32_t>(const int32_t *input, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, int32_t *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<int32_t, int64_t>(const int32_t *input, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, int32_t *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<int64_t, int32_t>(const int64_t *input, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, int64_t *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<int64_t, int64_t>(const int64_t *input, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, int64_t *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<half, int32_t>(const half *input, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, half *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<half, int64_t>(const half *input, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, half *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<float, int32_t>(const float *input, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, float *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<float, int64_t>(const float *input, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, float *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<double, int32_t>(const double *input, const int32_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, double *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template CUDA_LIB_EXPORT void CalMaxUnpool2D<double, int64_t>(const double *input, const int64_t *indices,
|
||||||
|
const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, double *output,
|
||||||
|
const int64_t outer_size, const int64_t thread_size,
|
||||||
|
const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,30 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXUNPOOL2D_IMPL_CUH_
|
||||||
|
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MAXUNPOOL2D_IMPL_CUH_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||||
|
#include "include/cuda_fp16.h"
|
||||||
|
|
||||||
|
template <typename T, typename S>
|
||||||
|
CUDA_LIB_EXPORT void CalMaxUnpool2D(const T *input, const S *indices, const std::vector<int64_t> input_shape,
|
||||||
|
const std::vector<int64_t> output_shape, T *output, const int64_t outer_size,
|
||||||
|
const int64_t thread_size, const std::string data_format_,
|
||||||
|
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||||
|
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ARGMAX_IMPL_CUH_
|
|
@ -0,0 +1,138 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "plugin/device/gpu/kernel/nn/maxunpool2d_gpu_kernel.h"
|
||||||
|
#include <utility>
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
namespace {
|
||||||
|
template <typename T, typename S>
|
||||||
|
//
|
||||||
|
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateMaxUnpool2DKernelPtr(const std::string &kernel_name,
|
||||||
|
const uint32_t &device_id) {
|
||||||
|
return std::make_unique<cukernel::MaxUnpool2DHelperGpuKernel<T, S>>(kernel_name, device_id);
|
||||||
|
}
|
||||||
|
using MaxUnpool2DPtrCreatorFunc =
|
||||||
|
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
|
||||||
|
|
||||||
|
const std::vector<std::pair<KernelAttr, MaxUnpool2DPtrCreatorFunc>> kernel_attr = {
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
|
||||||
|
CreateMaxUnpool2DKernelPtr<uint8_t, int32_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
|
||||||
|
CreateMaxUnpool2DKernelPtr<uint8_t, int64_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
|
||||||
|
CreateMaxUnpool2DKernelPtr<uint16_t, int32_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
|
||||||
|
CreateMaxUnpool2DKernelPtr<uint16_t, int64_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||||
|
CreateMaxUnpool2DKernelPtr<uint32_t, int32_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
|
||||||
|
CreateMaxUnpool2DKernelPtr<uint32_t, int64_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
|
||||||
|
CreateMaxUnpool2DKernelPtr<uint64_t, int32_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||||
|
CreateMaxUnpool2DKernelPtr<uint64_t, int64_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
|
||||||
|
CreateMaxUnpool2DKernelPtr<int8_t, int32_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
|
||||||
|
CreateMaxUnpool2DKernelPtr<int8_t, int64_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
|
||||||
|
CreateMaxUnpool2DKernelPtr<int16_t, int32_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
|
||||||
|
CreateMaxUnpool2DKernelPtr<int16_t, int64_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
CreateMaxUnpool2DKernelPtr<int32_t, int32_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
CreateMaxUnpool2DKernelPtr<int32_t, int64_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||||
|
CreateMaxUnpool2DKernelPtr<int64_t, int32_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||||
|
CreateMaxUnpool2DKernelPtr<int64_t, int64_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||||
|
CreateMaxUnpool2DKernelPtr<half, int32_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
|
||||||
|
CreateMaxUnpool2DKernelPtr<half, int64_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
CreateMaxUnpool2DKernelPtr<float, int32_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
CreateMaxUnpool2DKernelPtr<float, int64_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
CreateMaxUnpool2DKernelPtr<double, int32_t>},
|
||||||
|
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
CreateMaxUnpool2DKernelPtr<double, int64_t>}};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool MaxUnpool2DGPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||||
|
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
|
||||||
|
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
|
||||||
|
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
|
||||||
|
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MaxUnpool2DGPUKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) {
|
||||||
|
auto kernel_ptr = std::dynamic_pointer_cast<ops::MaxUnpool2D>(base_operator);
|
||||||
|
kernel_name_ = kernel_ptr->name();
|
||||||
|
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||||
|
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
|
||||||
|
if (!is_match) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
attr_ptr_->data_format = kernel_ptr->get_format();
|
||||||
|
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
|
||||||
|
helper_ptr_->SetKernelParam(attr_ptr_);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MaxUnpool2DGPUKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||||
|
for (const auto &input : inputs) {
|
||||||
|
auto input_shape = input->GetShapeVector();
|
||||||
|
if (!IsValidShape(input_shape)) {
|
||||||
|
return KRET_UNKNOWN_SHAPE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::vector<std::vector<int64_t>> input_shapes;
|
||||||
|
std::vector<std::vector<int64_t>> output_shapes;
|
||||||
|
std::vector<int64_t> inp_shape = inputs[0]->GetShapeVector();
|
||||||
|
std::vector<int64_t> indices_shape = inputs[1]->GetShapeVector();
|
||||||
|
std::vector<int64_t> out_shape = outputs[0]->GetShapeVector();
|
||||||
|
input_shapes.emplace_back(inp_shape);
|
||||||
|
input_shapes.emplace_back(indices_shape);
|
||||||
|
output_shapes.emplace_back(out_shape);
|
||||||
|
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
|
||||||
|
return KRET_RESIZE_FAILED;
|
||||||
|
}
|
||||||
|
input_size_list_ = helper_ptr_->GetInputSizeList();
|
||||||
|
output_size_list_ = helper_ptr_->GetOutputSizeList();
|
||||||
|
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
|
||||||
|
return KRET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<KernelAttr> MaxUnpool2DGPUKernelMod::GetOpSupport() {
|
||||||
|
std::vector<KernelAttr> support_list;
|
||||||
|
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
|
||||||
|
[](const std::pair<KernelAttr, MaxUnpool2DPtrCreatorFunc> &item) { return item.first; });
|
||||||
|
return support_list;
|
||||||
|
}
|
||||||
|
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, MaxUnpool2D, MaxUnpool2DGPUKernelMod);
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_MAXUNPOOL2D_GPU_KERNEL_H_
|
||||||
|
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_MAXUNPOOL2D_GPU_KERNEL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include "mindspore/core/ops/max_unpool2d.h"
|
||||||
|
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||||
|
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/maxunpool2d_helper.h"
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
class MaxUnpool2DGPUKernelMod : public NativeGpuKernelMod {
|
||||||
|
public:
|
||||||
|
MaxUnpool2DGPUKernelMod() { attr_ptr_ = std::make_shared<cukernel::MaxUnpool2DAttr>(); }
|
||||||
|
~MaxUnpool2DGPUKernelMod() override = default;
|
||||||
|
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||||
|
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) override;
|
||||||
|
int Resize(
|
||||||
|
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||||
|
std::vector<KernelAttr> GetOpSupport() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
|
||||||
|
std::shared_ptr<cukernel::MaxUnpool2DAttr> attr_ptr_{nullptr};
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_ARGMAX_GPU_KERNEL_H_
|
|
@ -0,0 +1,232 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "plugin/device/gpu/kernel/nn/maxunpool2d_grad_gpu_kernel.h"
|
||||||
|
#include <utility>
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
namespace {
|
||||||
|
constexpr size_t kInputIndex = 0;
|
||||||
|
constexpr size_t kInputGradIndex = 1;
|
||||||
|
constexpr size_t kInputIndicesIndex = 2;
|
||||||
|
template <typename T, typename S>
|
||||||
|
//
|
||||||
|
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateMaxUnpool2DGradKernelPtr(const std::string &kernel_name,
|
||||||
|
const uint32_t &device_id) {
|
||||||
|
return std::make_unique<cukernel::MaxUnpool2DGradHelperGpuKernel<T, S>>(kernel_name, device_id);
|
||||||
|
}
|
||||||
|
using MaxUnpool2DGradPtrCreatorFunc =
|
||||||
|
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
|
||||||
|
|
||||||
|
const std::vector<std::pair<KernelAttr, MaxUnpool2DGradPtrCreatorFunc>> kernel_attr = {
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeUInt8)
|
||||||
|
.AddInputAttr(kNumberTypeUInt8)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeUInt8),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<uint8_t, int32_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeUInt8)
|
||||||
|
.AddInputAttr(kNumberTypeUInt8)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeUInt8),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<uint8_t, int64_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeUInt16)
|
||||||
|
.AddInputAttr(kNumberTypeUInt16)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeUInt16),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<uint16_t, int32_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeUInt16)
|
||||||
|
.AddInputAttr(kNumberTypeUInt16)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeUInt16),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<uint16_t, int64_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeUInt32)
|
||||||
|
.AddInputAttr(kNumberTypeUInt32)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeUInt32),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<uint32_t, int32_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeUInt32)
|
||||||
|
.AddInputAttr(kNumberTypeUInt32)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeUInt32),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<uint32_t, int64_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeUInt64)
|
||||||
|
.AddInputAttr(kNumberTypeUInt64)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeUInt64),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<uint64_t, int32_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeUInt64)
|
||||||
|
.AddInputAttr(kNumberTypeUInt64)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeUInt64),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<uint64_t, int64_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeInt8)
|
||||||
|
.AddInputAttr(kNumberTypeInt8)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeInt8),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<int8_t, int32_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeInt8)
|
||||||
|
.AddInputAttr(kNumberTypeInt8)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeInt8),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<int8_t, int64_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeInt16)
|
||||||
|
.AddInputAttr(kNumberTypeInt16)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeInt16),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<int16_t, int32_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeInt16)
|
||||||
|
.AddInputAttr(kNumberTypeInt16)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeInt16),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<int16_t, int64_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeInt32),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<int32_t, int32_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeInt32),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<int32_t, int64_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeInt64),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<int64_t, int32_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeInt64),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<int64_t, int64_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat16),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<half, int32_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat16),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<half, int64_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<float, int32_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<float, int64_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat64)
|
||||||
|
.AddInputAttr(kNumberTypeFloat64)
|
||||||
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<double, int32_t>},
|
||||||
|
{KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat64)
|
||||||
|
.AddInputAttr(kNumberTypeFloat64)
|
||||||
|
.AddInputAttr(kNumberTypeInt64)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat64),
|
||||||
|
CreateMaxUnpool2DGradKernelPtr<double, int64_t>}};
|
||||||
|
} // namespace
|
||||||
|
bool MaxUnpool2DGradGPUKernelMod::Launch(const std::vector<AddressPtr> &inputs,
|
||||||
|
const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||||
|
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
|
||||||
|
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
|
||||||
|
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
|
||||||
|
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MaxUnpool2DGradGPUKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) {
|
||||||
|
auto kernel_ptr = std::dynamic_pointer_cast<ops::MaxUnpool2DGrad>(base_operator);
|
||||||
|
kernel_name_ = kernel_ptr->name();
|
||||||
|
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||||
|
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
|
||||||
|
if (!is_match) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
attr_ptr_->data_format = kernel_ptr->get_format();
|
||||||
|
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
|
||||||
|
helper_ptr_->SetKernelParam(attr_ptr_);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MaxUnpool2DGradGPUKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||||
|
const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||||
|
for (const auto &input : inputs) {
|
||||||
|
auto input_shape = input->GetShapeVector();
|
||||||
|
if (!IsValidShape(input_shape)) {
|
||||||
|
return KRET_UNKNOWN_SHAPE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::vector<std::vector<int64_t>> input_shapes;
|
||||||
|
std::vector<std::vector<int64_t>> output_shapes;
|
||||||
|
std::vector<int64_t> inp_shape = inputs[kInputIndex]->GetShapeVector();
|
||||||
|
std::vector<int64_t> grad_shape = inputs[kInputGradIndex]->GetShapeVector();
|
||||||
|
std::vector<int64_t> indices_shape = inputs[kInputIndicesIndex]->GetShapeVector();
|
||||||
|
std::vector<int64_t> out_shape = outputs[0]->GetShapeVector();
|
||||||
|
input_shapes.emplace_back(inp_shape);
|
||||||
|
input_shapes.emplace_back(grad_shape);
|
||||||
|
input_shapes.emplace_back(indices_shape);
|
||||||
|
output_shapes.emplace_back(out_shape);
|
||||||
|
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
|
||||||
|
return KRET_RESIZE_FAILED;
|
||||||
|
}
|
||||||
|
input_size_list_ = helper_ptr_->GetInputSizeList();
|
||||||
|
output_size_list_ = helper_ptr_->GetOutputSizeList();
|
||||||
|
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
|
||||||
|
return KRET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<KernelAttr> MaxUnpool2DGradGPUKernelMod::GetOpSupport() {
|
||||||
|
std::vector<KernelAttr> support_list;
|
||||||
|
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
|
||||||
|
[](const std::pair<KernelAttr, MaxUnpool2DGradPtrCreatorFunc> &item) { return item.first; });
|
||||||
|
return support_list;
|
||||||
|
}
|
||||||
|
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, MaxUnpool2DGrad, MaxUnpool2DGradGPUKernelMod);
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_MAXUNPOOL2D_GRAD_GPU_KERNEL_H_
|
||||||
|
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_MAXUNPOOL2D_GRAD_GPU_KERNEL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include "mindspore/core/ops/grad/max_unpool2d_grad.h"
|
||||||
|
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||||
|
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||||
|
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/maxunpool2d_helper.h"
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
class MaxUnpool2DGradGPUKernelMod : public NativeGpuKernelMod {
|
||||||
|
public:
|
||||||
|
MaxUnpool2DGradGPUKernelMod() { attr_ptr_ = std::make_shared<cukernel::MaxUnpool2DGradAttr>(); }
|
||||||
|
~MaxUnpool2DGradGPUKernelMod() override = default;
|
||||||
|
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||||
|
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) override;
|
||||||
|
int Resize(
|
||||||
|
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs,
|
||||||
|
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||||
|
std::vector<KernelAttr> GetOpSupport() override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
|
||||||
|
std::shared_ptr<cukernel::MaxUnpool2DGradAttr> attr_ptr_{nullptr};
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_ARGMAX_GPU_KERNEL_H_
|
|
@ -73,6 +73,10 @@ AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const
|
||||||
auto infer_shape = MaxUnpool2DGradInferShape(primitive, input_args);
|
auto infer_shape = MaxUnpool2DGradInferShape(primitive, input_args);
|
||||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||||
}
|
}
|
||||||
|
std::string MaxUnpool2DGrad::get_format() const {
|
||||||
|
auto value_ptr = GetAttr("format");
|
||||||
|
return GetValue<std::string>(value_ptr);
|
||||||
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(MaxUnpool2DGrad, prim::kPrimMaxUnpool2DGrad, MaxUnpool2DGradInfer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(MaxUnpool2DGrad, prim::kPrimMaxUnpool2DGrad, MaxUnpool2DGradInfer, nullptr, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#define MINDSPORE_CORE_OPS_MAXUNPOOL2DGRAD_H_
|
#define MINDSPORE_CORE_OPS_MAXUNPOOL2DGRAD_H_
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "ops/base_operator.h"
|
#include "ops/base_operator.h"
|
||||||
#include "mindapi/base/types.h"
|
#include "mindapi/base/types.h"
|
||||||
|
@ -29,6 +30,7 @@ class MIND_API MaxUnpool2DGrad : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
MIND_API_BASE_MEMBER(MaxUnpool2DGrad);
|
MIND_API_BASE_MEMBER(MaxUnpool2DGrad);
|
||||||
MaxUnpool2DGrad() : BaseOperator(kNameMaxUnpool2DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); }
|
MaxUnpool2DGrad() : BaseOperator(kNameMaxUnpool2DGrad) { InitIOName({"x", "grads", "argmax"}, {"y"}); }
|
||||||
|
std::string get_format() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
abstract::AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr MaxUnpool2DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
|
|
@ -37,9 +37,6 @@ abstract::ShapePtr MaxUnpool2DInferShapeCompute(const std::string &data_format,
|
||||||
int64_t out_w = static_cast<int64_t>((in_shape[kInputIndex3] - 1) * strides[kInputIndex3] - 2 * pads[kInputIndex3] +
|
int64_t out_w = static_cast<int64_t>((in_shape[kInputIndex3] - 1) * strides[kInputIndex3] - 2 * pads[kInputIndex3] +
|
||||||
ksize[kInputIndex3]);
|
ksize[kInputIndex3]);
|
||||||
std::vector<int64_t> out_shape = {in_shape[kInputIndex0], in_shape[kInputIndex1], out_h, out_w};
|
std::vector<int64_t> out_shape = {in_shape[kInputIndex0], in_shape[kInputIndex1], out_h, out_w};
|
||||||
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
|
|
||||||
MS_LOG(EXCEPTION) << "MaxUnpool2D: Output size is not valid.";
|
|
||||||
}
|
|
||||||
if (attr_output_shape.size() == kDim4) {
|
if (attr_output_shape.size() == kDim4) {
|
||||||
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
|
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
|
||||||
in_shape[kInputIndex0], op_name);
|
in_shape[kInputIndex0], op_name);
|
||||||
|
@ -71,9 +68,6 @@ abstract::ShapePtr MaxUnpool2DInferShapeCompute(const std::string &data_format,
|
||||||
ksize[kInputIndex2]);
|
ksize[kInputIndex2]);
|
||||||
std::vector<int64_t> out_shape = {in_shape[kInputIndex0], out_h, out_w, in_shape[kInputIndex3]};
|
std::vector<int64_t> out_shape = {in_shape[kInputIndex0], out_h, out_w, in_shape[kInputIndex3]};
|
||||||
|
|
||||||
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t a) { return a <= 0; })) {
|
|
||||||
MS_LOG(EXCEPTION) << "MaxUnpool2D: Output size is not valid.";
|
|
||||||
}
|
|
||||||
if (attr_output_shape.size() == kDim4) {
|
if (attr_output_shape.size() == kDim4) {
|
||||||
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
|
(void)CheckAndConvertUtils::CheckInteger("output_shape[0]", attr_output_shape[kInputIndex0], kEqual,
|
||||||
in_shape[kInputIndex0], op_name);
|
in_shape[kInputIndex0], op_name);
|
||||||
|
@ -153,6 +147,11 @@ AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const Prim
|
||||||
auto infer_shape = MaxUnpool2DInferShape(primitive, input_args);
|
auto infer_shape = MaxUnpool2DInferShape(primitive, input_args);
|
||||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string MaxUnpool2D::get_format() const {
|
||||||
|
auto value_ptr = GetAttr("format");
|
||||||
|
return GetValue<std::string>(value_ptr);
|
||||||
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(MaxUnpool2D, prim::kPrimMaxUnpool2D, MaxUnpool2DInfer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(MaxUnpool2D, prim::kPrimMaxUnpool2D, MaxUnpool2DInfer, nullptr, true);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#define MINDSPORE_CORE_OPS_MAXUNPOOL2D_H_
|
#define MINDSPORE_CORE_OPS_MAXUNPOOL2D_H_
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "ops/base_operator.h"
|
#include "ops/base_operator.h"
|
||||||
#include "mindapi/base/types.h"
|
#include "mindapi/base/types.h"
|
||||||
|
@ -29,6 +30,7 @@ class MIND_API MaxUnpool2D : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
MIND_API_BASE_MEMBER(MaxUnpool2D);
|
MIND_API_BASE_MEMBER(MaxUnpool2D);
|
||||||
MaxUnpool2D() : BaseOperator(kNameMaxUnpool2D) { InitIOName({"x", "argmax"}, {"y"}); }
|
MaxUnpool2D() : BaseOperator(kNameMaxUnpool2D) { InitIOName({"x", "argmax"}, {"y"}); }
|
||||||
|
std::string get_format() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
abstract::AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
abstract::AbstractBasePtr MaxUnpool2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
|
|
@ -0,0 +1,165 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore import Tensor
|
||||||
|
import mindspore.ops.operations.nn_ops as ops
|
||||||
|
import mindspore.ops.operations._grad_ops as grad_ops
|
||||||
|
|
||||||
|
|
||||||
|
class NetMaxUnpool2DFourD(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetMaxUnpool2DFourD, self).__init__()
|
||||||
|
self.maxunpool2d_fun = ops.MaxUnpool2D(ksize=(3, 2), strides=(3, 2), pads=0, data_format='NCHW')
|
||||||
|
|
||||||
|
def construct(self, x, indices):
|
||||||
|
return self.maxunpool2d_fun(x, indices)
|
||||||
|
|
||||||
|
|
||||||
|
class NetMaxUnpool2DGradFourD(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetMaxUnpool2DGradFourD, self).__init__()
|
||||||
|
self.maxunpool2d_grad = grad_ops.MaxUnpool2DGrad(ksize=(1, 1, 3, 2), strides=(1, 1, 3, 2), pads=(1, 1, 0, 0),
|
||||||
|
data_format='NCHW')
|
||||||
|
|
||||||
|
def construct(self, x, grad, indices):
|
||||||
|
return self.maxunpool2d_grad(x, grad, indices)
|
||||||
|
|
||||||
|
|
||||||
|
class NetMaxUnpool2DFourDNHWC(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetMaxUnpool2DFourDNHWC, self).__init__()
|
||||||
|
self.maxunpool2d_fun = ops.MaxUnpool2D(ksize=(3, 2), strides=(3, 2), pads=0, data_format='NHWC')
|
||||||
|
|
||||||
|
def construct(self, x, indices):
|
||||||
|
return self.maxunpool2d_fun(x, indices)
|
||||||
|
|
||||||
|
|
||||||
|
class NetMaxUnpool2DGradFourDNHWC(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetMaxUnpool2DGradFourDNHWC, self).__init__()
|
||||||
|
self.maxunpool2d_grad = grad_ops.MaxUnpool2DGrad(ksize=(1, 1, 3, 2), strides=(1, 1, 3, 2), pads=(1, 1, 0, 0),
|
||||||
|
data_format='NHWC')
|
||||||
|
|
||||||
|
def construct(self, x, grad, indices):
|
||||||
|
return self.maxunpool2d_grad(x, grad, indices)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_maxunpool2d_4dinput_graph():
|
||||||
|
"""
|
||||||
|
Feature: MaxUnpool2d 4dinput graph
|
||||||
|
Description: 4dinput graph
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||||
|
indices_type = [np.int32, np.int64]
|
||||||
|
inout_types = [np.int8, np.int16, np.int32, np.int64,
|
||||||
|
np.uint8, np.uint16, np.uint32, np.uint64, np.float16, np.float32, np.float64]
|
||||||
|
for indices_type_i in indices_type:
|
||||||
|
for inout_type_i in inout_types:
|
||||||
|
x = Tensor(np.array([[[[10, 12], [22, 24]],
|
||||||
|
[[34, 36], [46, 48]]],
|
||||||
|
[[[58, 60], [70, 72]],
|
||||||
|
[[82, 84], [94, 96]]]]).astype(inout_type_i))
|
||||||
|
indices = Tensor(np.array([[[[9, 11], [21, 23]],
|
||||||
|
[[9, 11], [21, 23]]],
|
||||||
|
[[[9, 11], [21, 23]],
|
||||||
|
[[9, 11], [21, 23]]]]).astype(indices_type_i))
|
||||||
|
maxunpool2d = NetMaxUnpool2DFourD()
|
||||||
|
y = maxunpool2d(x, indices)
|
||||||
|
output_type = y.asnumpy().dtype
|
||||||
|
expect_result = Tensor(np.array([[[[0, 0, 0, 0], [0, 0, 0, 0],
|
||||||
|
[0, 10, 0, 12], [0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0], [0, 22, 0, 24]],
|
||||||
|
[[0, 0, 0, 0], [0, 0, 0, 0],
|
||||||
|
[0, 34, 0, 36], [0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0], [0, 46, 0, 48]]],
|
||||||
|
[[[0, 0, 0, 0], [0, 0, 0, 0],
|
||||||
|
[0, 58, 0, 60], [0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0], [0, 70, 0, 72]],
|
||||||
|
[[0, 0, 0, 0], [0, 0, 0, 0],
|
||||||
|
[0, 82, 0, 84], [0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0], [0, 94, 0, 96]]]]).astype(inout_type_i))
|
||||||
|
assert np.allclose(expect_result.asnumpy(), y.asnumpy())
|
||||||
|
assert output_type == inout_type_i
|
||||||
|
|
||||||
|
maxunpoo2dgrad = NetMaxUnpool2DGradFourD()
|
||||||
|
grad = Tensor(np.array([i+1 for i in range(4*24)]).reshape([2, 2, 6, 4]).astype(inout_type_i))
|
||||||
|
output_grad = maxunpoo2dgrad(x, grad, indices)
|
||||||
|
output_grad_type = output_grad.asnumpy().dtype
|
||||||
|
expect_output_grad = Tensor(np.array([[[[10, 12], [22, 24]],
|
||||||
|
[[34, 36], [46, 48]]],
|
||||||
|
[[[58, 60], [70, 72]],
|
||||||
|
[[82, 84], [94, 96]]]]).astype(inout_type_i))
|
||||||
|
assert np.allclose(expect_output_grad.asnumpy(), output_grad.asnumpy())
|
||||||
|
assert output_grad_type == inout_type_i
|
||||||
|
|
||||||
|
|
||||||
|
def test_maxunpool2d_4dinput_pynative():
|
||||||
|
"""
|
||||||
|
Feature: MaxUnpool2d 4dinput pynative
|
||||||
|
Description: 4dinput pynative
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||||
|
indices_type = [np.int32, np.int64]
|
||||||
|
inout_types = [np.int8, np.int16, np.int32, np.int64,
|
||||||
|
np.uint8, np.uint16, np.uint32, np.uint64, np.float16, np.float32, np.float64]
|
||||||
|
for indices_type_i in indices_type:
|
||||||
|
for inout_type_i in inout_types:
|
||||||
|
x = Tensor(np.array([[[[10, 12], [22, 24]],
|
||||||
|
[[34, 36], [46, 48]]],
|
||||||
|
[[[58, 60], [70, 72]],
|
||||||
|
[[82, 84], [94, 96]]]]).astype(inout_type_i)).transpose(0, 2, 3, 1)
|
||||||
|
indices = Tensor(np.array([[[[9, 11], [21, 23]],
|
||||||
|
[[9, 11], [21, 23]]],
|
||||||
|
[[[9, 11], [21, 23]],
|
||||||
|
[[9, 11], [21, 23]]]]).astype(indices_type_i)).transpose(0, 2, 3, 1)
|
||||||
|
maxunpool2d = NetMaxUnpool2DFourDNHWC()
|
||||||
|
y = maxunpool2d(x, indices)
|
||||||
|
output_type = y.asnumpy().dtype
|
||||||
|
expect_result = Tensor(np.array([[[[0, 0, 0, 0], [0, 0, 0, 0],
|
||||||
|
[0, 10, 0, 12], [0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0], [0, 22, 0, 24]],
|
||||||
|
[[0, 0, 0, 0], [0, 0, 0, 0],
|
||||||
|
[0, 34, 0, 36], [0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0], [0, 46, 0, 48]]],
|
||||||
|
[[[0, 0, 0, 0], [0, 0, 0, 0],
|
||||||
|
[0, 58, 0, 60], [0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0], [0, 70, 0, 72]],
|
||||||
|
[[0, 0, 0, 0], [0, 0, 0, 0],
|
||||||
|
[0, 82, 0, 84], [0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0], [0, 94, 0, 96]]]])
|
||||||
|
.astype(inout_type_i)).transpose(0, 2, 3, 1)
|
||||||
|
assert np.allclose(expect_result.asnumpy(), y.asnumpy())
|
||||||
|
assert output_type == inout_type_i
|
||||||
|
|
||||||
|
maxunpoo2dgrad = NetMaxUnpool2DGradFourDNHWC()
|
||||||
|
grad = Tensor(np.array([i+1 for i in range(4*24)]).reshape([2, 2, 6, 4])
|
||||||
|
.astype(inout_type_i)).transpose(0, 2, 3, 1)
|
||||||
|
output_grad = maxunpoo2dgrad(x, grad, indices)
|
||||||
|
output_grad_type = output_grad.asnumpy().dtype
|
||||||
|
expect_output_grad = Tensor(np.array([[[[10, 12], [22, 24]],
|
||||||
|
[[34, 36], [46, 48]]],
|
||||||
|
[[[58, 60], [70, 72]],
|
||||||
|
[[82, 84], [94, 96]]]]).astype(inout_type_i)).transpose(0, 2, 3, 1)
|
||||||
|
assert np.allclose(expect_output_grad.asnumpy(), output_grad.asnumpy())
|
||||||
|
assert output_grad_type == inout_type_i
|
Loading…
Reference in New Issue