add adaptive maxpool3d grad gpu op and testcase

parent 9e1f49bff2
commit 81e556e444
@@ -14,13 +14,14 @@
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL2D_GRAD_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL2D_GRAD_HELPER_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL_GRAD_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL_GRAD_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_grad_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cuh"

namespace mindspore {
namespace cukernel {
@@ -28,66 +29,39 @@ constexpr int64_t maxIndexIdx = 2;
constexpr int64_t dyDimSmall = 3;
constexpr int64_t hIdx = 2;

class AdaptiveMaxPool2DGradAttr : public GpuKernelAttrBase {
class AdaptiveMaxPoolGradAttr : public GpuKernelAttrBase {
 public:
  AdaptiveMaxPool2DGradAttr() = default;
  ~AdaptiveMaxPool2DGradAttr() override = default;
  AdaptiveMaxPoolGradAttr() = default;
  ~AdaptiveMaxPoolGradAttr() override = default;
};

template <typename T, typename S>
class AdaptiveMaxPool2DGradHelperGpuKernel : public GpuKernelHelperBase {
class AdaptiveMaxPoolGradHelperGpuKernel : public GpuKernelHelperBase {
 public:
  explicit AdaptiveMaxPool2DGradHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
  explicit AdaptiveMaxPoolGradHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
      : GpuKernelHelperBase(kernel_name, device_id) {
    is_null_input_ = false;
  }

  virtual ~AdaptiveMaxPool2DGradHelperGpuKernel() = default;
  virtual ~AdaptiveMaxPoolGradHelperGpuKernel() = default;

  int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
                 const std::vector<std::vector<int64_t>> &output_shapes) override {
    ResetResource();

    // cal input_size_list_ (dy, x, index)
    size_t dy_size = sizeof(T);
    for (auto val : input_shapes[0]) {
      dy_size *= val;
    is_null_input_ = CHECK_SHAPE_NULL(output_shapes[0], kernel_name_, "out_shape");
    if (is_null_input_) {
      return -1;
    }
    input_size_list_.emplace_back(dy_size);

    size_t x_size = sizeof(T);
    for (auto val : input_shapes[1]) {
      x_size *= val;
    }
    input_size_list_.emplace_back(x_size);

    size_t index_size = sizeof(S);
    for (auto val : input_shapes[maxIndexIdx]) {
      index_size *= val;
    }
    input_size_list_.emplace_back(index_size);

    // cal output_size_list_ (dx)
    int out_flag = CalShapesSizeInBytes<T>(output_shapes, 1, kernel_name_, "output_shapes", &output_size_list_);
    if (out_flag == -1) {
      return out_flag;
    }

    input_shape_.emplace_back(input_shapes[0]);            // dy
    input_shape_.emplace_back(input_shapes[1]);            // x
    input_shape_.emplace_back(input_shapes[maxIndexIdx]);  // index
    output_shape_ = output_shapes[0];                      // dx

    is_null_input_ = (out_flag == 1);
    return 0;
  }

  int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
              const std::vector<void *> &work_ptrs, void *cuda_stream) override {
    if (is_null_input_) {
      return 0;
    }

    // get device ptr input index output
    T *dy_ptr = nullptr;
    S *index_ptr = nullptr;

@@ -107,6 +81,17 @@ class AdaptiveMaxPool2DGradHelperGpuKernel : public GpuKernelHelperBase {
      return flag;
    }

    if (kernel_name_ == kAdaptiveMaxPool3DGradOpName) {
      const int64_t output_stride = output_shape_.cend()[-1] * output_shape_.cend()[-2] * output_shape_.cend()[-3];
      auto input_argmax_shape = input_shape_[maxIndexIdx];
      const int64_t argmax_stride =
        input_argmax_shape.cend()[-1] * input_argmax_shape.cend()[-2] * input_argmax_shape.cend()[-3];
      const int64_t batch = std::accumulate(input_argmax_shape.begin(), input_argmax_shape.end() - 3,
                                            static_cast<int64_t>(1), [=](int64_t a, int64_t b) { return a * b; });
      CalAdaptiveMaxPool3DGrad(dy_ptr, index_ptr, output_stride, argmax_stride, batch, dx_ptr, device_id_,
                               reinterpret_cast<cudaStream_t>(cuda_stream));
      return 0;
    }
    // call cuda kernel
    const int shape_dim = output_shape_.size();  // dx grad dim 3 or 4
    auto input_shape = input_shape_[0];          // dy

@@ -128,23 +113,20 @@ class AdaptiveMaxPool2DGradHelperGpuKernel : public GpuKernelHelperBase {
  }

  void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
    attr_ptr_ = std::dynamic_pointer_cast<AdaptiveMaxPool2DGradAttr>(kernel_attr);
    attr_ptr_ = std::dynamic_pointer_cast<AdaptiveMaxPoolGradAttr>(kernel_attr);
  }

  void ResetResource() override {
    input_shape_.clear();
    output_shape_.clear();
    input_size_list_.clear();
    output_size_list_.clear();
    work_size_list_.clear();
  }

 private:
  std::shared_ptr<AdaptiveMaxPool2DGradAttr> attr_ptr_;
  std::shared_ptr<AdaptiveMaxPoolGradAttr> attr_ptr_;
  std::vector<std::vector<int64_t>> input_shape_;  // 0:input_shape(y_grad) 2:index_shape(argmax)
  std::vector<int64_t> output_shape_;
  bool is_null_input_;
};
}  // namespace cukernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL2D_GRAD_HELPER_H_
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ADAPTIVE_MAX_POOL_GRAD_HELPER_H_
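Note on the 3D branch above: everything ahead of the last three dimensions is folded into a batch count, and the trailing D*H*W extents of dx and argmax become per-slice strides. A minimal NumPy sketch of that bookkeeping (function and variable names are illustrative, not from this commit):

import numpy as np

def pool3d_grad_strides(dx_shape, argmax_shape):
    # output_stride: elements in one D*H*W slice of dx (the last three dims).
    output_stride = int(np.prod(dx_shape[-3:]))
    # argmax_stride: elements in one D*H*W slice of the argmax tensor.
    argmax_stride = int(np.prod(argmax_shape[-3:]))
    # batch: product of all leading dims; np.prod of an empty tuple is 1.
    batch = int(np.prod(argmax_shape[:-3]))
    return output_stride, argmax_stride, batch

# e.g. dx of shape (1, 2, 4, 4) pooled down to argmax of shape (1, 2, 2, 2),
# matching the testcase at the end of this commit:
print(pool3d_grad_strides((1, 2, 4, 4), (1, 2, 2, 2)))  # (32, 8, 1)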
@@ -1,59 +1,74 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool2d_grad_impl.cuh"
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"

template <typename T, typename S>
__global__ void AdaptiveMaxPool2DGradKernel(const T *input_data, const S *max_index, const int input_nchw,
                                            const int input_hw, const int output_hw, T *output_data) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_nchw; pos += blockDim.x * gridDim.x) {
    const S idx = max_index[pos];
    const int posn = pos / input_hw;
    MsAtomicAdd(output_data + posn * output_hw + static_cast<int>(idx), input_data[pos]);
  }
  return;
}

template <typename T, typename S>
void CalAdaptiveMaxPool2DGrad(const T *input_data, const S *max_index, const int n, const int c,
                              const uint input_height, const uint input_width,
                              const uint output_height, const uint output_width, T *output_data,
                              const uint32_t &device_id, cudaStream_t cuda_stream) {
  const int input_hw = input_height * input_width;
  const int input_chw = c * input_hw;
  const int input_nchw = n * input_chw;
  const int output_hw = output_height * output_width;

  AdaptiveMaxPool2DGradKernel<<<CUDA_BLOCKS(device_id, input_nchw), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    input_data, max_index, input_nchw, input_hw, output_hw, output_data);
}

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool2DGrad<half, int>(
  const half *input_data, const int *max_index, const int n, const int c, const uint input_height,
  const uint input_width, const uint output_height, const uint output_width, half *output_data,
  const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool2DGrad<float, int>(
  const float *input_data, const int *max_index, const int n, const int c, const uint input_height,
  const uint input_width, const uint output_height, const uint output_width, float *output_data,
  const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool2DGrad<double, int>(
  const double *input_data, const int *max_index, const int n, const int c, const uint input_height,
  const uint input_width, const uint output_height, const uint output_width, double *output_data,
  const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool2DGrad<half, int64_t>(
  const half *input_data, const int64_t *max_index, const int n, const int c, const uint input_height,
  const uint input_width, const uint output_height, const uint output_width, half *output_data,
  const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool2DGrad<float, int64_t>(
  const float *input_data, const int64_t *max_index, const int n, const int c, const uint input_height,
  const uint input_width, const uint output_height, const uint output_width, float *output_data,
  const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool2DGrad<double, int64_t>(
  const double *input_data, const int64_t *max_index, const int n, const int c, const uint input_height,
  const uint input_width, const uint output_height, const uint output_width, double *output_data,
  const uint32_t &device_id, cudaStream_t cuda_stream);
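For reference, the 2D kernel scatters each incoming gradient into the flat argmax position of its (n, c) plane of dx, and MsAtomicAdd accumulates collisions. A hedged NumPy equivalent of that semantics (names illustrative, not from this commit):

import numpy as np

def adaptive_max_pool2d_grad_ref(dy, argmax, out_h, out_w):
    # dy, argmax: shape (N, C, H', W'); argmax holds flat indices into the
    # out_h * out_w plane of dx, matching the CUDA kernel's addressing.
    n, c = dy.shape[:2]
    dx = np.zeros((n, c, out_h, out_w), dtype=dy.dtype)
    dx_flat = dx.reshape(n * c, out_h * out_w)
    dy_flat = dy.reshape(n * c, -1)
    idx_flat = argmax.reshape(n * c, -1)
    for plane in range(n * c):
        # np.add.at accumulates like MsAtomicAdd when indices repeat.
        np.add.at(dx_flat[plane], idx_flat[plane], dy_flat[plane])
    return dx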
@@ -0,0 +1,67 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/adaptive_max_pool3d_grad_impl.cuh"
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"

template <typename T, typename S>
__global__ void AdaptiveMaxPool3DGradKernel(const T *input_grad, const S *input_argmax, const int output_stride,
                                            const int argmax_stride, const int batch, T *output_data) {
  for (size_t n = blockIdx.x * blockDim.x + threadIdx.x; n < batch; n += blockDim.x * gridDim.x) {
    for (int64_t i = 0; i < argmax_stride; ++i) {
      int32_t maxp = input_argmax[i + n * argmax_stride] + n * output_stride;
      MsAtomicAdd(output_data + static_cast<int>(maxp), input_grad[i + n * argmax_stride]);
    }
  }
  return;
}

template <typename T, typename S>
void CalAdaptiveMaxPool3DGrad(const T *input_grad, const S *input_argmax, const int output_stride,
                              const int argmax_stride, const int batch, T *output_data, const uint32_t &device_id,
                              cudaStream_t cuda_stream) {
  AdaptiveMaxPool3DGradKernel<<<CUDA_BLOCKS(device_id, batch), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    input_grad, input_argmax, output_stride, argmax_stride, batch, output_data);
}

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool3DGrad<half, int>(const half *input_grad, const int *input_argmax,
                                                                  const int output_stride, const int argmax_stride,
                                                                  const int batch, half *output_data,
                                                                  const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool3DGrad<float, int>(const float *input_grad, const int *input_argmax,
                                                                   const int output_stride, const int argmax_stride,
                                                                   const int batch, float *output_data,
                                                                   const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool3DGrad<double, int>(const double *input_grad, const int *input_argmax,
                                                                    const int output_stride, const int argmax_stride,
                                                                    const int batch, double *output_data,
                                                                    const uint32_t &device_id,
                                                                    cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool3DGrad<half, int64_t>(
  const half *input_grad, const int64_t *input_argmax, const int output_stride, const int argmax_stride,
  const int batch, half *output_data, const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool3DGrad<float, int64_t>(
  const float *input_grad, const int64_t *input_argmax, const int output_stride, const int argmax_stride,
  const int batch, float *output_data, const uint32_t &device_id, cudaStream_t cuda_stream);

template CUDA_LIB_EXPORT void CalAdaptiveMaxPool3DGrad<double, int64_t>(
  const double *input_grad, const int64_t *input_argmax, const int output_stride, const int argmax_stride,
  const int batch, double *output_data, const uint32_t &device_id, cudaStream_t cuda_stream);
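The 3D kernel assigns one thread per batch slice: slice n owns output_stride contiguous elements of dx, and every gradient in that slice is accumulated at its argmax offset within the slice. A minimal NumPy sketch of the same accumulation (names illustrative):

import numpy as np

def adaptive_max_pool3d_grad_ref(dy_flat, argmax_flat, output_stride, argmax_stride, batch, dtype=np.float32):
    # dy_flat, argmax_flat: 1-D arrays of length batch * argmax_stride.
    dx = np.zeros(batch * output_stride, dtype=dtype)
    for n in range(batch):
        lo = n * argmax_stride
        # Offset argmax indices into batch slice n of dx, then accumulate.
        np.add.at(dx, n * output_stride + argmax_flat[lo:lo + argmax_stride], dy_flat[lo:lo + argmax_stride])
    return dx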
@@ -0,0 +1,27 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL3D_GRAD_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL3D_GRAD_IMPL_CUH_

#include <cuda_runtime.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

template <typename T, typename S>
CUDA_LIB_EXPORT void CalAdaptiveMaxPool3DGrad(const T *input_grad, const S *input_argmax, const int output_stride,
                                              const int argmax_stride, const int batch, T *output_data,
                                              const uint32_t &device_id, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_ADAPTIVE_MAX_POOL3D_GRAD_IMPL_CUH_
@@ -1,132 +1,129 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/nn/adaptive_max_pool2d_grad_gpu_kernel.h"
#include <utility>

namespace mindspore {
namespace kernel {
constexpr int64_t maxIndexIdx = 2;

namespace {
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateAdaptiveMaxPool2DGradKernelPtr(const std::string &kernel_name,
                                                                                    const uint32_t &device_id) {
  return std::make_unique<cukernel::AdaptiveMaxPool2DGradHelperGpuKernel<T, S>>(kernel_name, device_id);
}

using AdaptiveMaxPool2DGradPtrCreatorFunc =
  std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;

const std::vector<std::pair<KernelAttr, AdaptiveMaxPool2DGradPtrCreatorFunc>> kernel_attr = {
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat16),
   CreateAdaptiveMaxPool2DGradKernelPtr<half, int64_t>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat32),
   CreateAdaptiveMaxPool2DGradKernelPtr<float, int64_t>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat64),
   CreateAdaptiveMaxPool2DGradKernelPtr<double, int64_t>},
};
}  // namespace

bool AdaptiveMaxPool2DGradGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
                                               const std::vector<AddressPtr> &workspace,
                                               const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  std::vector<void *> input_ptrs = ConvertPtrs(inputs);
  std::vector<void *> work_ptrs = ConvertPtrs(workspace);
  std::vector<void *> output_ptrs = ConvertPtrs(outputs);

  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemsetAsync(output_ptrs[0], 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
    "failed to set cuda memory with zeros.");

  if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
    return false;
  }
  return true;
}

bool AdaptiveMaxPool2DGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                             const std::vector<KernelTensorPtr> &inputs,
                                             const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::AdaptiveMaxPool2DGrad>(base_operator);
  kernel_name_ = kernel_ptr->name();
  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
  if (!is_match) {
    return false;
  }

  helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
  helper_ptr_->SetKernelParam(attr_ptr_);

  return true;
}

int AdaptiveMaxPool2DGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                              const std::vector<KernelTensorPtr> &inputs,
                                              const std::vector<KernelTensorPtr> &outputs,
                                              const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  int ret = KernelMod::Resize(base_operator, inputs, outputs);
  if (ret != KRET_OK) {
    return ret;
  }

  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::vector<int64_t>> output_shapes;
  std::vector<int64_t> input_shape = inputs[0]->GetShapeVector();
  std::vector<int64_t> x_shape = inputs[1]->GetShapeVector();
  std::vector<int64_t> index_shape = inputs[maxIndexIdx]->GetShapeVector();
  std::vector<int64_t> out_shape = outputs[0]->GetShapeVector();

  (void)input_shapes.emplace_back(input_shape);
  (void)input_shapes.emplace_back(x_shape);
  (void)input_shapes.emplace_back(index_shape);
  (void)output_shapes.emplace_back(out_shape);
  if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
    return KRET_RESIZE_FAILED;
  }

  input_size_list_ = helper_ptr_->GetInputSizeList();
  output_size_list_ = helper_ptr_->GetOutputSizeList();
  workspace_size_list_ = helper_ptr_->GetWorkSizeList();
  return KRET_OK;
}

std::vector<KernelAttr> AdaptiveMaxPool2DGradGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(
    kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
    [](const std::pair<KernelAttr, AdaptiveMaxPool2DGradPtrCreatorFunc> &item) { return item.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AdaptiveMaxPool2DGrad, AdaptiveMaxPool2DGradGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/nn/adaptive_max_pool2d_grad_gpu_kernel.h"
#include <utility>

namespace mindspore {
namespace kernel {
constexpr int64_t maxIndexIdx = 2;

namespace {
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateAdaptiveMaxPool2DGradKernelPtr(const std::string &kernel_name,
                                                                                    const uint32_t &device_id) {
  return std::make_unique<cukernel::AdaptiveMaxPoolGradHelperGpuKernel<T, S>>(kernel_name, device_id);
}

using AdaptiveMaxPool2DGradPtrCreatorFunc =
  std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;

const std::vector<std::pair<KernelAttr, AdaptiveMaxPool2DGradPtrCreatorFunc>> kernel_attr = {
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat16),
   CreateAdaptiveMaxPool2DGradKernelPtr<half, int64_t>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat32),
   CreateAdaptiveMaxPool2DGradKernelPtr<float, int64_t>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat64),
   CreateAdaptiveMaxPool2DGradKernelPtr<double, int64_t>},
};
}  // namespace

bool AdaptiveMaxPool2DGradGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
                                               const std::vector<AddressPtr> &workspace,
                                               const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  std::vector<void *> input_ptrs = ConvertPtrs(inputs);
  std::vector<void *> work_ptrs = ConvertPtrs(workspace);
  std::vector<void *> output_ptrs = ConvertPtrs(outputs);

  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemsetAsync(output_ptrs[0], 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
    "failed to set cuda memory with zeros.");

  if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
    return false;
  }
  return true;
}

bool AdaptiveMaxPool2DGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                             const std::vector<KernelTensorPtr> &inputs,
                                             const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::AdaptiveMaxPool2DGrad>(base_operator);
  kernel_name_ = kernel_ptr->name();
  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
  if (!is_match) {
    return false;
  }

  helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
  helper_ptr_->SetKernelParam(attr_ptr_);

  return true;
}

int AdaptiveMaxPool2DGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                              const std::vector<KernelTensorPtr> &inputs,
                                              const std::vector<KernelTensorPtr> &outputs,
                                              const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  int ret = KernelMod::Resize(base_operator, inputs, outputs);
  if (ret != KRET_OK) {
    return ret;
  }

  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::vector<int64_t>> output_shapes;
  std::vector<int64_t> input_shape = inputs[0]->GetShapeVector();
  std::vector<int64_t> x_shape = inputs[1]->GetShapeVector();
  std::vector<int64_t> index_shape = inputs[maxIndexIdx]->GetShapeVector();
  std::vector<int64_t> out_shape = outputs[0]->GetShapeVector();

  (void)input_shapes.emplace_back(input_shape);
  (void)input_shapes.emplace_back(x_shape);
  (void)input_shapes.emplace_back(index_shape);
  (void)output_shapes.emplace_back(out_shape);
  if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
    return KRET_RESIZE_FAILED;
  }

  return KRET_OK;
}

std::vector<KernelAttr> AdaptiveMaxPool2DGradGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(
    kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
    [](const std::pair<KernelAttr, AdaptiveMaxPool2DGradPtrCreatorFunc> &item) { return item.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AdaptiveMaxPool2DGrad, AdaptiveMaxPool2DGradGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
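The registered 2D kernel is reachable from Python through the backprop primitive. A hedged usage sketch, assuming the primitive's input order (y_grad, x, argmax) implied by the kernel attr list above; the argmax values are flat indices into each 4x4 plane:

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# dy and argmax would normally come from a forward adaptive max pool.
x = Tensor(np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4))
dy = Tensor(np.ones((1, 1, 2, 2), dtype=np.float32))
argmax = Tensor(np.array([[[[5, 7], [13, 15]]]], dtype=np.int64))
dx = G.AdaptiveMaxPool2DGrad()(dy, x, argmax)
print(dx.shape)  # (1, 1, 4, 4)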
@@ -1,57 +1,57 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL2D_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL2D_GRAD_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <map>
#include <memory>
#include <algorithm>
#include <functional>
#include "mindspore/core/ops/grad/adaptive_max_pool2d_grad.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_max_pool2d_grad_helper.h"
namespace mindspore {
namespace kernel {
class AdaptiveMaxPool2DGradGpuKernelMod : public NativeGpuKernelMod {
 public:
  AdaptiveMaxPool2DGradGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::AdaptiveMaxPool2DGradAttr>(); }
  ~AdaptiveMaxPool2DGradGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
  std::shared_ptr<cukernel::AdaptiveMaxPool2DGradAttr> attr_ptr_{nullptr};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL2D_GRAD_GPU_KERNEL_H_
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL2D_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL2D_GRAD_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <map>
#include <memory>
#include <algorithm>
#include <functional>
#include "mindspore/core/ops/grad/adaptive_max_pool2d_grad.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_max_pool_grad_helper.h"
namespace mindspore {
namespace kernel {
class AdaptiveMaxPool2DGradGpuKernelMod : public NativeGpuKernelMod {
 public:
  AdaptiveMaxPool2DGradGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::AdaptiveMaxPoolGradAttr>(); }
  ~AdaptiveMaxPool2DGradGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
  std::shared_ptr<cukernel::AdaptiveMaxPoolGradAttr> attr_ptr_{nullptr};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL2D_GRAD_GPU_KERNEL_H_
@@ -0,0 +1,146 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/nn/adaptive_max_pool3d_grad_gpu_kernel.h"
#include <utility>

namespace mindspore {
namespace kernel {
constexpr int64_t maxIndexIdx = 2;

namespace {
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateAdaptiveMaxPoolGradKernelPtr(const std::string &kernel_name,
                                                                                  const uint32_t &device_id) {
  return std::make_unique<cukernel::AdaptiveMaxPoolGradHelperGpuKernel<T, S>>(kernel_name, device_id);
}

using AdaptiveMaxPoolGradPtrCreatorFunc =
  std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;

const std::vector<std::pair<KernelAttr, AdaptiveMaxPoolGradPtrCreatorFunc>> kernel_attr = {
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat16),
   CreateAdaptiveMaxPoolGradKernelPtr<half, int>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat32),
   CreateAdaptiveMaxPoolGradKernelPtr<float, int>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeFloat64),
   CreateAdaptiveMaxPoolGradKernelPtr<double, int>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat16),
   CreateAdaptiveMaxPoolGradKernelPtr<half, int64_t>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat32),
   CreateAdaptiveMaxPoolGradKernelPtr<float, int64_t>},

  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeFloat64),
   CreateAdaptiveMaxPoolGradKernelPtr<double, int64_t>},
};
}  // namespace

bool AdaptiveMaxPool3DGradGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
                                               const std::vector<AddressPtr> &workspace,
                                               const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  std::vector<void *> input_ptrs = ConvertPtrs(inputs);
  std::vector<void *> work_ptrs = ConvertPtrs(workspace);
  std::vector<void *> output_ptrs = ConvertPtrs(outputs);
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemsetAsync(output_ptrs[0], 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
    "failed to set cuda memory with zeros.");

  if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
    return false;
  }
  return true;
}

bool AdaptiveMaxPool3DGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                             const std::vector<KernelTensorPtr> &inputs,
                                             const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->name();
  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
  if (!is_match) {
    return false;
  }
  helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
  helper_ptr_->SetKernelParam(attr_ptr_);

  return true;
}

int AdaptiveMaxPool3DGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                              const std::vector<KernelTensorPtr> &inputs,
                                              const std::vector<KernelTensorPtr> &outputs,
                                              const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  int ret = KernelMod::Resize(base_operator, inputs, outputs);
  if (ret != KRET_OK) {
    return ret;
  }

  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::vector<int64_t>> output_shapes;
  std::vector<int64_t> input_shape = inputs[0]->GetShapeVector();
  std::vector<int64_t> x_shape = inputs[1]->GetShapeVector();
  std::vector<int64_t> index_shape = inputs[maxIndexIdx]->GetShapeVector();
  std::vector<int64_t> out_shape = outputs[0]->GetShapeVector();

  (void)input_shapes.emplace_back(input_shape);
  (void)input_shapes.emplace_back(x_shape);
  (void)input_shapes.emplace_back(index_shape);
  (void)output_shapes.emplace_back(out_shape);

  if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
    return KRET_RESIZE_FAILED;
  }
  return KRET_OK;
}

std::vector<KernelAttr> AdaptiveMaxPool3DGradGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, AdaptiveMaxPoolGradPtrCreatorFunc> &item) { return item.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AdaptiveMaxPool3DGrad, AdaptiveMaxPool3DGradGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
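The support list above is the Cartesian product of the three floating dtypes with the two argmax integer dtypes, so argmax tensors arriving as either int32 or int64 dispatch to the same shared helper. A quick sketch of the six registered combinations:

from itertools import product

# (grad/x dtype, argmax dtype) pairs registered above.
supported = list(product(("float16", "float32", "float64"), ("int32", "int64")))
print(supported)  # six combinations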
@@ -0,0 +1,57 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <map>
#include <memory>
#include <algorithm>
#include <functional>
#include "mindspore/core/ops/grad/adaptive_max_pool_3d_grad.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/adaptive_max_pool_grad_helper.h"
namespace mindspore {
namespace kernel {
class AdaptiveMaxPool3DGradGpuKernelMod : public NativeGpuKernelMod {
 public:
  AdaptiveMaxPool3DGradGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::AdaptiveMaxPoolGradAttr>(); }
  ~AdaptiveMaxPool3DGradGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
  std::shared_ptr<cukernel::AdaptiveMaxPoolGradAttr> attr_ptr_{nullptr};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_ADAPTIVE_MAX_POOL3D_GRAD_GPU_KERNEL_H_
@@ -69,7 +69,7 @@ TypePtr AdaptiveMaxPool3DGradInferType(const PrimitivePtr &, const std::vector<A
  auto argmax_dtype = input_args[2]->BuildType();
  const std::set<TypePtr> real_number_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
                                               kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
  const std::set<TypePtr> argmax_valid_types = {kInt32};
  const std::set<TypePtr> argmax_valid_types = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_grad_dtype", input_grad_dtype, real_number_types,
                                                   kNameAdaptiveMaxPool3DGrad);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, real_number_types, kNameAdaptiveMaxPool3DGrad);
@@ -3305,6 +3305,7 @@ class AdaptiveMaxPool3DGrad(Primitive):
    @prim_attr_register
    def __init__(self):
        """Initialize AdaptiveMaxPool3DGrad"""
        self.init_prim_io_names(inputs=['input_grad', 'x', 'argmax'], outputs=['output_grad'])


class TraceGrad(Primitive):
@@ -0,0 +1,168 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G


class NetAdaptiveMaxPool3DGrad(nn.Cell):
    def __init__(self):
        super(NetAdaptiveMaxPool3DGrad, self).__init__()
        self.adaptive_max_pool3d_grad_fun = G.AdaptiveMaxPool3DGrad()

    def construct(self, dy, x, argmax):
        return self.adaptive_max_pool3d_grad_fun(dy, x, argmax)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_adaptive_max_pool3d_grad_fp32():
    """
    Feature: test AdaptiveMaxPool3DGrad op on GPU.
    Description: scatter float32 gradients to int64 argmax positions in PyNative and Graph mode.
    Expectation: expect correct output values.
    """
    x = Tensor(np.array([[
        [
            [0, 1, 2, 3],
            [4, 5, 6, 7],
            [8, 9, 10, 11],
            [12, 13, 14, 15]
        ],
        [
            [0, 1, 2, 3],
            [4, 5, 6, 7],
            [8, 9, 10, 11],
            [12, 13, 14, 15]
        ]
    ]]).astype(np.float32))
    dy = Tensor(np.array([[
        [
            [0.7, 0.9],
            [0.19, 0.21]
        ],
        [
            [0.7, 0.9],
            [0.19, 0.21]
        ],
    ]]).astype(np.float32))
    index = Tensor(np.array([[
        [
            [5, 7],
            [13, 15]
        ],
        [
            [21, 23],
            [29, 31]
        ]
    ]]).astype(np.int64))
    expect_result = np.array([[
        [
            [0., 0., 0., 0.],
            [0., 0.7, 0., 0.9],
            [0., 0., 0., 0.],
            [0., 0.19, 0., 0.21],
        ],
        [
            [0., 0., 0., 0.],
            [0., 0.7, 0., 0.9],
            [0., 0., 0., 0.],
            [0., 0.19, 0., 0.21],
        ],
    ]])

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    adaptive_max_pool3d_grad = NetAdaptiveMaxPool3DGrad()
    output = adaptive_max_pool3d_grad(dy, x, index)
    assert np.allclose(expect_result, output.asnumpy())

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    adaptive_max_pool3d_grad = NetAdaptiveMaxPool3DGrad()
    output = adaptive_max_pool3d_grad(dy, x, index)
    assert np.allclose(expect_result, output.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_adaptive_max_pool3d_grad_fp64():
    """
    Feature: test AdaptiveMaxPool3DGrad op on GPU.
    Description: scatter float64 gradients to int64 argmax positions in PyNative and Graph mode.
    Expectation: expect correct output values.
    """
    x = Tensor(np.array([[
        [
            [0, 1, 2, 3],
            [4, 5, 6, 7],
            [8, 9, 10, 11],
            [12, 13, 14, 15]
        ],
        [
            [0, 1, 2, 3],
            [4, 5, 6, 7],
            [8, 9, 10, 11],
            [12, 13, 14, 15]
        ]
    ]]).astype(np.float64))
    dy = Tensor(np.array([[
        [
            [0.7, 0.9],
            [0.19, 0.21]
        ],
        [
            [0.7, 0.9],
            [0.19, 0.21]
        ],
    ]]).astype(np.float64))
    index = Tensor(np.array([[
        [
            [5, 7],
            [13, 15]
        ],
        [
            [21, 23],
            [29, 31]
        ]
    ]]).astype(np.int64))
    expect_result = np.array([[
        [
            [0., 0., 0., 0.],
            [0., 0.7, 0., 0.9],
            [0., 0., 0., 0.],
            [0., 0.19, 0., 0.21],
        ],
        [
            [0., 0., 0., 0.],
            [0., 0.7, 0., 0.9],
            [0., 0., 0., 0.],
            [0., 0.19, 0., 0.21],
        ],
    ]])
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    adaptive_max_pool3d_grad = NetAdaptiveMaxPool3DGrad()
    output = adaptive_max_pool3d_grad(dy, x, index)
    assert np.allclose(expect_result, output.asnumpy())

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    adaptive_max_pool3d_grad = NetAdaptiveMaxPool3DGrad()
    output = adaptive_max_pool3d_grad(dy, x, index)
    assert np.allclose(expect_result, output.asnumpy())
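The expected dx in these tests follows from the kernel's flat-index convention: dx has shape (1, 2, 4, 4), so output_stride is 2*4*4 = 32 and each argmax value is an offset into that flattened (D, H, W) block. A small sketch decoding a few of the test indices (illustrative only):

import numpy as np

# Decode flat argmax indices into (d, h, w) positions of the (2, 4, 4) dx block.
for flat, grad in [(5, 0.7), (7, 0.9), (21, 0.7), (23, 0.9)]:
    d, rem = divmod(flat, 4 * 4)
    h, w = divmod(rem, 4)
    print(f"index {flat} -> dx[0, {d}, {h}, {w}] += {grad}")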