!35402 [assistant][ops] New GPU operator implementation, include FractionalMaxPool3DWithFixedKsize, FractionalMaxPool3DGradWithFixedKsize

Merge pull request !35402 from 黎冠新/FractionalMaxPool3DWithFixedKsize
This commit is contained in:
i-robot 2022-07-19 09:40:31 +00:00 committed by Gitee
commit cf17b1134a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
16 changed files with 1681 additions and 14 deletions

View File

@ -0,0 +1,249 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include <random>
#include <utility>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fractionalmaxpool3dgradwithfixedksize_impl.cuh"
namespace mindspore {
namespace cukernel {
constexpr size_t kDimSize4 = 4;
constexpr size_t kDimSize5 = 5;
constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kOutputIndex0 = 0;
constexpr size_t kOutputIndex1 = 1;
constexpr size_t kFormatNCDHWIndexC = 0;
constexpr size_t kFormatNCDHWIndexD = 1;
constexpr size_t kFormatNCDHWIndexH = 2;
constexpr size_t kFormatNCDHWIndexW = 3;
constexpr size_t kFormatNDHWCIndexD = 0;
constexpr size_t kFormatNDHWCIndexH = 1;
constexpr size_t kFormatNDHWCIndexW = 2;
constexpr size_t kFormatNDHWCIndexC = 3;
constexpr size_t kInputsNum = 3;
constexpr size_t kOutputsNum = 1;
class FractionalMaxPool3DGradWithFixedKsizeAttr : public GpuKernelAttrBase {
public:
FractionalMaxPool3DGradWithFixedKsizeAttr() = default;
~FractionalMaxPool3DGradWithFixedKsizeAttr() override = default;
std::string data_format;
};
template <typename T, typename S>
class FractionalMaxPool3DGradWithFixedKsizeHelperGpuKernel : public GpuKernelHelperBase {
public:
explicit FractionalMaxPool3DGradWithFixedKsizeHelperGpuKernel(const std::string &kernel_name,
const uint32_t &device_id)
: GpuKernelHelperBase(kernel_name, device_id) {
is_null_input_ = false;
}
virtual ~FractionalMaxPool3DGradWithFixedKsizeHelperGpuKernel() = default;
int CheckDims() {
size_t input_dims = origin_input_shape_.size();
size_t out_backprop_dims = out_backprop_shape_.size();
size_t argmax_dims = argmax_shape_.size();
if (!(input_dims == kDimSize4 || input_dims == kDimSize5)) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'input' must be equal to 4 or 5, but got "
<< input_dims << ".";
return -1;
}
if (!(out_backprop_dims == kDimSize4 || out_backprop_dims == kDimSize5)) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'out_backprop' must be equal to 4 or 5, but got "
<< out_backprop_dims << ".";
return -1;
}
if (!(argmax_dims == kDimSize4 || argmax_dims == kDimSize5)) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'argmax' must be equal to 4 or 5, but got "
<< argmax_dims << ".";
return -1;
}
return 1;
}
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes) override {
constexpr size_t OUTPUT_NUM = 1;
ResetResource();
origin_input_shape_ = input_shapes[kInputIndex0];
out_backprop_shape_ = input_shapes[kInputIndex1];
argmax_shape_ = input_shapes[kInputIndex2];
output_shape_ = output_shapes[kOutputIndex0];
size_t c_dim;
size_t d_dim;
size_t h_dim;
size_t w_dim;
data_format_ = attr_ptr_->data_format;
if (data_format_ != "NCDHW" && data_format_ != "NDHWC") {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', data_format must be NCDHW or NDHWC, but got " << data_format_
<< ".";
return -1;
}
if (data_format_ == "NCDHW") {
c_dim = kFormatNCDHWIndexC;
d_dim = kFormatNCDHWIndexD;
h_dim = kFormatNCDHWIndexH;
w_dim = kFormatNCDHWIndexW;
} else {
c_dim = kFormatNDHWCIndexC;
d_dim = kFormatNDHWCIndexD;
h_dim = kFormatNDHWCIndexH;
w_dim = kFormatNDHWCIndexW;
}
if (origin_input_shape_.size() == kDimSize5) {
inputN_ = origin_input_shape_[0];
c_dim++;
d_dim++;
h_dim++;
w_dim++;
}
inputC_ = origin_input_shape_[c_dim];
inputD_ = origin_input_shape_[d_dim];
inputH_ = origin_input_shape_[h_dim];
inputW_ = origin_input_shape_[w_dim];
outputD_ = out_backprop_shape_[d_dim];
outputH_ = out_backprop_shape_[h_dim];
outputW_ = out_backprop_shape_[w_dim];
int dims_flag = CheckDims();
if (dims_flag == -1) {
return dims_flag;
}
int inp_flag = 0;
size_t cur_size_T = sizeof(T);
for (const auto &val : origin_input_shape_) {
cur_size_T *= val;
}
if (cur_size_T == 0 && inp_flag == 0) {
inp_flag = 1;
}
input_size_list_.emplace_back(cur_size_T);
cur_size_T = sizeof(T);
for (const auto &val : out_backprop_shape_) {
cur_size_T *= val;
}
if (cur_size_T == 0 && inp_flag == 0) {
inp_flag = 1;
}
input_size_list_.emplace_back(cur_size_T);
size_t cur_size_S = sizeof(S);
for (const auto &val : argmax_shape_) {
cur_size_S *= val;
}
if (cur_size_S == 0 && inp_flag == 0) {
inp_flag = 1;
}
input_size_list_.emplace_back(cur_size_S);
int out_flag =
CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
if (out_flag == -1) {
return out_flag;
}
is_null_input_ = (inp_flag == 1 || out_flag == 1);
return 0;
}
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
if (is_null_input_) {
return 0;
}
T *origin_input_ptr = nullptr;
T *out_backprop_ptr = nullptr;
S *argmax_ptr = nullptr;
T *output_ptr = nullptr;
int flag = GetDeviceAddress<T>(input_ptrs, kInputIndex0, kernel_name_, &origin_input_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(input_ptrs, kInputIndex1, kernel_name_, &out_backprop_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<S>(input_ptrs, kInputIndex2, kernel_name_, &argmax_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(output_ptrs, kOutputIndex0, kernel_name_, &output_ptr);
if (flag != 0) {
return flag;
}
int64_t dims = static_cast<int64_t>(output_shape_.size());
int64_t outer_size = 1;
for (int64_t i = dims - 1; i >= 0; i--) {
outer_size *= output_shape_[i];
}
dims = static_cast<int64_t>(out_backprop_shape_.size());
int64_t out_backprop_size = 1;
for (int64_t i = dims - 1; i >= 0; i--) {
out_backprop_size *= out_backprop_shape_[i];
}
CalFractionalmaxpool3dgradwithfixedksize(origin_input_ptr, out_backprop_ptr, argmax_ptr, output_ptr, outputD_,
outputH_, outputW_, inputN_, inputC_, inputD_, inputH_, inputW_,
outer_size, out_backprop_size, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
return 0;
}
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
attr_ptr_ = std::dynamic_pointer_cast<FractionalMaxPool3DGradWithFixedKsizeAttr>(kernel_attr);
}
private:
std::string data_format_;
int64_t outputD_{1};
int64_t outputH_{1};
int64_t outputW_{1};
int64_t inputN_{1};
int64_t inputC_{1};
int64_t inputD_{1};
int64_t inputH_{1};
int64_t inputW_{1};
std::shared_ptr<FractionalMaxPool3DGradWithFixedKsizeAttr> attr_ptr_;
std::vector<int64_t> origin_input_shape_;
std::vector<int64_t> out_backprop_shape_;
std::vector<int64_t> argmax_shape_;
std::vector<int64_t> output_shape_;
bool is_null_input_;
};
} // namespace cukernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_HELPER_H_

View File

@ -0,0 +1,285 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include <random>
#include <utility>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fractionalmaxpool3dwithfixedksize_impl.cuh"
namespace mindspore {
namespace cukernel {
constexpr size_t kDimSize2 = 2;
constexpr size_t kDimSize3 = 3;
constexpr size_t kDimSize4 = 4;
constexpr size_t kDimSize5 = 5;
constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1;
constexpr size_t kInputIndex2 = 2;
constexpr size_t kOutputIndex0 = 0;
constexpr size_t kOutputIndex1 = 1;
constexpr size_t kkernelsizeIndexD = 0;
constexpr size_t kkernelsizeIndexH = 1;
constexpr size_t kkernelsizeIndexW = 2;
constexpr size_t kOutputshapeIndexD = 0;
constexpr size_t kOutputshapeIndexH = 1;
constexpr size_t kOutputshapeIndexW = 2;
constexpr size_t kDimSize5FormatNCDHWIndexN = 0;
constexpr size_t kDimSize5FormatNCDHWIndexC = 1;
constexpr size_t kDimSize5FormatNCDHWIndexD = 2;
constexpr size_t kDimSize5FormatNCDHWIndexH = 3;
constexpr size_t kDimSize5FormatNCDHWIndexW = 4;
constexpr size_t kDimSize5FormatNDHWCIndexN = 0;
constexpr size_t kDimSize5FormatNDHWCIndexD = 1;
constexpr size_t kDimSize5FormatNDHWCIndexH = 2;
constexpr size_t kDimSize5FormatNDHWCIndexW = 3;
constexpr size_t kDimSize5FormatNDHWCIndexC = 4;
constexpr size_t kDimSize4FormatNCDHWIndexC = 0;
constexpr size_t kDimSize4FormatNCDHWIndexD = 1;
constexpr size_t kDimSize4FormatNCDHWIndexH = 2;
constexpr size_t kDimSize4FormatNCDHWIndexW = 3;
constexpr size_t kDimSize4FormatNDHWCIndexD = 0;
constexpr size_t kDimSize4FormatNDHWCIndexH = 1;
constexpr size_t kDimSize4FormatNDHWCIndexW = 2;
constexpr size_t kDimSize4FormatNDHWCIndexC = 3;
class FractionalMaxPool3DWithFixedKsizeAttr : public GpuKernelAttrBase {
public:
FractionalMaxPool3DWithFixedKsizeAttr() = default;
~FractionalMaxPool3DWithFixedKsizeAttr() override = default;
std::vector<float> ksize;
std::vector<int64_t> output_shape;
std::string data_format;
};
template <typename T, typename S, typename G>
class FractionalMaxPool3DWithFixedKsizeHelperGpuKernel : public GpuKernelHelperBase {
public:
explicit FractionalMaxPool3DWithFixedKsizeHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
: GpuKernelHelperBase(kernel_name, device_id) {
is_null_input_ = false;
}
virtual ~FractionalMaxPool3DWithFixedKsizeHelperGpuKernel() = default;
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes) override {
ResetResource();
input_shape_ = input_shapes[kInputIndex0];
random_samples_shape_ = input_shapes[kInputIndex1];
output_shape_ = output_shapes[kOutputIndex0];
argmax_shape_ = output_shapes[kOutputIndex1];
data_format_ = attr_ptr_->data_format;
if (data_format_ != "NCDHW" && data_format_ != "NDHWC") {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', data_format must be NCDHW or NDHWC, but got " << data_format_
<< ".";
return -1;
}
if (data_format_ == "NCDHW") {
if (input_shape_.size() == kDimSize5) {
inputN_ = input_shape_[kDimSize5FormatNCDHWIndexN];
inputC_ = input_shape_[kDimSize5FormatNCDHWIndexC];
inputD_ = input_shape_[kDimSize5FormatNCDHWIndexD];
inputH_ = input_shape_[kDimSize5FormatNCDHWIndexH];
inputW_ = input_shape_[kDimSize5FormatNCDHWIndexW];
} else {
inputC_ = input_shape_[kDimSize4FormatNCDHWIndexC];
inputD_ = input_shape_[kDimSize4FormatNCDHWIndexD];
inputH_ = input_shape_[kDimSize4FormatNCDHWIndexH];
inputW_ = input_shape_[kDimSize4FormatNCDHWIndexW];
}
} else {
if (input_shape_.size() == kDimSize5) {
inputN_ = input_shape_[kDimSize5FormatNDHWCIndexN];
inputC_ = input_shape_[kDimSize5FormatNDHWCIndexC];
inputD_ = input_shape_[kDimSize5FormatNDHWCIndexD];
inputH_ = input_shape_[kDimSize5FormatNDHWCIndexH];
inputW_ = input_shape_[kDimSize5FormatNDHWCIndexW];
} else {
inputC_ = input_shape_[kDimSize4FormatNDHWCIndexC];
inputD_ = input_shape_[kDimSize4FormatNDHWCIndexD];
inputH_ = input_shape_[kDimSize4FormatNDHWCIndexH];
inputW_ = input_shape_[kDimSize4FormatNDHWCIndexW];
}
}
int inp_flag = 0;
size_t cur_size_T = sizeof(T);
for (const auto &val : input_shape_) {
cur_size_T *= val;
}
if (cur_size_T == 0 && inp_flag == 0) {
inp_flag = 1;
}
input_size_list_.emplace_back(cur_size_T);
size_t cur_size_S = sizeof(S);
for (const auto &val : random_samples_shape_) {
cur_size_S *= val;
}
if (cur_size_S == 0 && inp_flag == 0) {
inp_flag = 1;
}
input_size_list_.emplace_back(cur_size_S);
int out_flag = 0;
cur_size_T = sizeof(T);
for (const auto &val : output_shape_) {
cur_size_T *= val;
}
cur_size_T *= inputC_;
cur_size_T *= inputN_;
if (cur_size_T == 0 && out_flag == 0) {
out_flag = 1;
}
output_size_list_.emplace_back(cur_size_T);
size_t cur_size_G = sizeof(G);
for (const auto &val : argmax_shape_) {
cur_size_G *= val;
}
if (cur_size_G == 0 && out_flag == 0) {
out_flag = 1;
}
output_size_list_.emplace_back(cur_size_G);
is_null_input_ = (inp_flag == 1 || out_flag == 1);
return CheckKernelParam();
}
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
if (is_null_input_) {
return 0;
}
T *input_ptr = nullptr;
S *random_samples_ptr = nullptr;
T *output_ptr = nullptr;
G *argmax_ptr = nullptr;
int flag = GetDeviceAddress<T>(input_ptrs, kInputIndex0, kernel_name_, &input_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<S>(input_ptrs, kInputIndex1, kernel_name_, &random_samples_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(output_ptrs, kOutputIndex0, kernel_name_, &output_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<G>(output_ptrs, kOutputIndex1, kernel_name_, &argmax_ptr);
if (flag != 0) {
return flag;
}
int64_t dims = static_cast<int64_t>(output_shape_.size());
int64_t outer_size = 1;
for (int64_t i = dims - 1; i >= 0; i--) {
outer_size *= output_shape_[i];
}
CalFractionalmaxpool3dwithfixedksize(input_ptr, random_samples_ptr, output_ptr, argmax_ptr, outputD_, outputH_,
outputW_, inputN_, inputC_, inputD_, inputH_, inputW_, kernelsizeD_,
kernelsizeH_, kernelsizeW_, outer_size, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
return 0;
}
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
attr_ptr_ = std::dynamic_pointer_cast<FractionalMaxPool3DWithFixedKsizeAttr>(kernel_attr);
}
protected:
int CheckKernelParam() override {
ksize_ = attr_ptr_->ksize;
output_shape_attr_ = attr_ptr_->output_shape;
// data_format_ = attr_ptr_->data_format;
kernelsizeD_ = ksize_[kkernelsizeIndexD];
kernelsizeH_ = ksize_[kkernelsizeIndexH];
kernelsizeW_ = ksize_[kkernelsizeIndexW];
outputD_ = output_shape_attr_[kOutputshapeIndexD];
outputH_ = output_shape_attr_[kOutputshapeIndexH];
outputW_ = output_shape_attr_[kOutputshapeIndexW];
size_t input_num_dims = input_shape_.size();
size_t random_samples_dims = random_samples_shape_.size();
size_t output_shape_dims = output_shape_attr_.size();
size_t ksize_dims = ksize_.size();
if (!(input_num_dims == kDimSize4 || input_num_dims == kDimSize5)) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' must be equal to 4 or 5, but got "
<< input_num_dims << ".";
return -1;
}
if (random_samples_dims != kDimSize3) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'random_samples' must be equal to 3, but got "
<< random_samples_dims << ".";
return -1;
}
if (output_shape_dims != kDimSize3) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'output_shape' must be equal to 3, but got "
<< output_shape_dims << ".";
return -1;
}
if (ksize_dims != kDimSize3) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'ksize' must be equal to 3, but got "
<< ksize_dims << ".";
return -1;
}
if (random_samples_shape_[kDimSize2] != kDimSize3) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', expected the third dimension of 'random_samples' must be 3, but got "
<< random_samples_shape_[kDimSize2] << ".";
return -1;
}
return 0;
}
private:
std::vector<float> ksize_;
std::vector<int64_t> output_shape_attr_;
std::string data_format_;
int64_t outputD_{1};
int64_t outputH_{1};
int64_t outputW_{1};
int64_t inputN_{1};
int64_t inputC_{1};
int64_t inputD_{1};
int64_t inputH_{1};
int64_t inputW_{1};
int64_t kernelsizeD_{1};
int64_t kernelsizeH_{1};
int64_t kernelsizeW_{1};
std::shared_ptr<FractionalMaxPool3DWithFixedKsizeAttr> attr_ptr_;
std::vector<int64_t> input_shape_;
std::vector<int64_t> random_samples_shape_;
std::vector<int64_t> output_shape_;
std::vector<int64_t> argmax_shape_;
bool is_null_input_;
};
} // namespace cukernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_HELPER_H_

View File

@ -0,0 +1,97 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fractionalmaxpool3dgradwithfixedksize_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
template <typename T>
__global__ void InitOutput(T *output, const int64_t outer_size) {
T zero = 0;
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < outer_size; id += blockDim.x * gridDim.x) {
output[id] = zero;
}
return;
}
template <typename T, typename S>
__global__ void Fractionalmaxpool3dgradwithfixedksize(const T *origin_input, const T *out_backprop, S *argmax,
T *output, int64_t outputD, int64_t outputH, int64_t outputW,
int64_t N, int64_t C, int64_t inputD, int64_t inputH,
int64_t inputW, const int64_t out_backprop_size) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < out_backprop_size; pos += blockDim.x * gridDim.x) {
const int posn = pos / (C * outputD * outputH * outputW);
const int posc = pos / (outputD * outputH * outputW) % C;
S maxind = argmax[pos];
MsAtomicAdd(output + (posn * C + posc) * inputD * inputH * inputW + maxind, out_backprop[pos]);
return;
}
}
template <typename T, typename S>
void CalFractionalmaxpool3dgradwithfixedksize(const T *origin_input, const T *out_backprop, S *argmax, T *output,
int64_t outputD, int64_t outputH, int64_t outputW, int64_t inputN,
int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size,
const uint32_t &device_id, cudaStream_t cuda_stream) {
InitOutput<<<CUDA_BLOCKS(device_id, outer_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(output, outer_size);
Fractionalmaxpool3dgradwithfixedksize<<<CUDA_BLOCKS(device_id, out_backprop_size), CUDA_THREADS(device_id), 0,
cuda_stream>>>(origin_input, out_backprop, argmax, output, outputD, outputH,
outputW, inputN, inputC, inputD, inputH, inputW,
out_backprop_size);
return;
}
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dgradwithfixedksize<half, int32_t>(
const half *origin_input, const half *out_backprop, int32_t *argmax, half *output, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dgradwithfixedksize<float, int32_t>(
const float *origin_input, const float *out_backprop, int32_t *argmax, float *output, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dgradwithfixedksize<double, int32_t>(
const double *origin_input, const double *out_backprop, int32_t *argmax, double *output, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dgradwithfixedksize<int32_t, int32_t>(
const int32_t *origin_input, const int32_t *out_backprop, int32_t *argmax, int32_t *output, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dgradwithfixedksize<int64_t, int32_t>(
const int64_t *origin_input, const int64_t *out_backprop, int32_t *argmax, int64_t *output, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dgradwithfixedksize<half, int64_t>(
const half *origin_input, const half *out_backprop, int64_t *argmax, half *output, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dgradwithfixedksize<float, int64_t>(
const float *origin_input, const float *out_backprop, int64_t *argmax, float *output, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dgradwithfixedksize<double, int64_t>(
const double *origin_input, const double *out_backprop, int64_t *argmax, double *output, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dgradwithfixedksize<int32_t, int64_t>(
const int32_t *origin_input, const int32_t *out_backprop, int64_t *argmax, int32_t *output, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dgradwithfixedksize<int64_t, int64_t>(
const int64_t *origin_input, const int64_t *out_backprop, int64_t *argmax, int64_t *output, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size, const uint32_t &device_id, cudaStream_t cuda_stream);

View File

@ -0,0 +1,31 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FRACTIONALMAXPOOL3DGRADWITHFIXEDKSIZE_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FRACTIONALMAXPOOL3DGRADWITHFIXEDKSIZE_IMPL_CUH_
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T, typename S>
CUDA_LIB_EXPORT void CalFractionalmaxpool3dgradwithfixedksize(const T *origin_input, const T *out_backprop, S *argmax,
T *output, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC,
int64_t inputD, int64_t inputH, int64_t inputW,
const int64_t outer_size, const int64_t out_backprop_size,
const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FRACTIONALMAXPOOL3DGRADWITHFIXEDKSIZE_IMPL_CUH_

View File

@ -0,0 +1,249 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/fractionalmaxpool3dwithfixedksize_impl.cuh"
#include <limits>
template <typename S>
__device__ inline int64_t get_intervals(S sample, int64_t index, int64_t inputSize, int64_t outputSize,
int64_t poolSize) {
S alpha = static_cast<S>(inputSize - poolSize) / static_cast<S>(outputSize - 1);
if (index == outputSize - 1) {
return inputSize - poolSize;
} else {
return static_cast<int64_t>((index + sample) * alpha) - static_cast<int64_t>(sample * alpha);
}
}
// half
template <>
__device__ inline int64_t get_intervals(half sample, int64_t index, int64_t inputSize, int64_t outputSize,
int64_t poolSize) {
float alpha = static_cast<float>(inputSize - poolSize) / static_cast<float>(outputSize - 1);
if (index == outputSize - 1) {
return inputSize - poolSize;
} else {
return static_cast<int64_t>((index + __half2float(sample)) * alpha) -
static_cast<int64_t>(__half2float(sample) * alpha);
}
}
template <typename T, typename S, typename G>
__global__ void Fractionalmaxpool3dwithfixedksize(const T *input, const S *random_samples, T *output, G *argmax,
int64_t outputD, int64_t outputH, int64_t outputW, int64_t N,
int64_t C, int64_t inputD, int64_t inputH, int64_t inputW,
int64_t kernelsizeD, int64_t kernelsizeH, int64_t kernelsizeW,
const int64_t outer_size) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < outer_size; pos += blockDim.x * gridDim.x) {
const int posn = pos / (C * outputD * outputH * outputW);
const int posc = pos / (outputD * outputH * outputW) % C;
const int post = pos / (outputH * outputW) % outputD;
const int posh = pos / outputW % outputH;
const int posw = pos % outputW;
int64_t poolT = get_intervals<S>(random_samples[(posn * C + posc) * 3 + 0], post, inputD, outputD, kernelsizeD);
int64_t poolH = get_intervals<S>(random_samples[(posn * C + posc) * 3 + 1], posh, inputH, outputH, kernelsizeH);
int64_t poolW = get_intervals<S>(random_samples[(posn * C + posc) * 3 + 2], posw, inputW, outputW, kernelsizeW);
int64_t maxIndex = (((posn * C + posc) * inputD + poolT) * inputH + poolH) * inputW + poolW;
T maxVal = input[maxIndex];
maxIndex = (poolT * inputH + poolH) * inputW + poolW;
for (int64_t t = poolT; t < poolT + kernelsizeD; ++t) {
for (int64_t h = poolH; h < poolH + kernelsizeH; ++h) {
for (int64_t w = poolW; w < poolW + kernelsizeW; ++w) {
int64_t index = (((posn * C + posc) * inputD + t) * inputH + h) * inputW + w;
T val = input[index];
if (val > maxVal) {
maxVal = val;
maxIndex = (t * inputH + h) * inputW + w;
}
}
}
}
argmax[pos] = static_cast<G>(maxIndex);
output[pos] = maxVal;
}
return;
}
template <typename T, typename S, typename G>
void CalFractionalmaxpool3dwithfixedksize(const T *input, const S *random_samples, T *output, G *argmax,
int64_t outputD, int64_t outputH, int64_t outputW, int64_t inputN,
int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
int64_t kernelsizeD, int64_t kernelsizeH, int64_t kernelsizeW,
const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream) {
Fractionalmaxpool3dwithfixedksize<<<CUDA_BLOCKS(device_id, outer_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
input, random_samples, output, argmax, outputD, outputH, outputW, inputN, inputC, inputD, inputH, inputW,
kernelsizeD, kernelsizeH, kernelsizeW, outer_size);
return;
}
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<half, half, int32_t>(
const half *input, const half *random_samples, half *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<half, float, int64_t>(
const half *input, const float *random_samples, half *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<half, double, int64_t>(
const half *input, const double *random_samples, half *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<half, half, int64_t>(
const half *input, const half *random_samples, half *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<half, float, int32_t>(
const half *input, const float *random_samples, half *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<half, double, int32_t>(
const half *input, const double *random_samples, half *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<float, half, int32_t>(
const float *input, const half *random_samples, float *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<float, float, int64_t>(
const float *input, const float *random_samples, float *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<float, double, int64_t>(
const float *input, const double *random_samples, float *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<float, half, int64_t>(
const float *input, const half *random_samples, float *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<float, float, int32_t>(
const float *input, const float *random_samples, float *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<float, double, int32_t>(
const float *input, const double *random_samples, float *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<double, half, int32_t>(
const double *input, const half *random_samples, double *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<double, float, int64_t>(
const double *input, const float *random_samples, double *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<double, double, int64_t>(
const double *input, const double *random_samples, double *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<double, half, int64_t>(
const double *input, const half *random_samples, double *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<double, float, int32_t>(
const double *input, const float *random_samples, double *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<double, double, int32_t>(
const double *input, const double *random_samples, double *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int32_t, half, int32_t>(
const int32_t *input, const half *random_samples, int32_t *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int32_t, float, int64_t>(
const int32_t *input, const float *random_samples, int32_t *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int32_t, double, int64_t>(
const int32_t *input, const double *random_samples, int32_t *output, int64_t *argmax, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
int64_t kernelsizeD, int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int32_t, half, int64_t>(
const int32_t *input, const half *random_samples, int32_t *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int32_t, float, int32_t>(
const int32_t *input, const float *random_samples, int32_t *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int32_t, double, int32_t>(
const int32_t *input, const double *random_samples, int32_t *output, int32_t *argmax, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
int64_t kernelsizeD, int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int64_t, half, int32_t>(
const int64_t *input, const half *random_samples, int64_t *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int64_t, float, int64_t>(
const int64_t *input, const float *random_samples, int64_t *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int64_t, double, int64_t>(
const int64_t *input, const double *random_samples, int64_t *output, int64_t *argmax, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
int64_t kernelsizeD, int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int64_t, half, int64_t>(
const int64_t *input, const half *random_samples, int64_t *output, int64_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int64_t, float, int32_t>(
const int64_t *input, const float *random_samples, int64_t *output, int32_t *argmax, int64_t outputD, int64_t outputH,
int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize<int64_t, double, int32_t>(
const int64_t *input, const double *random_samples, int64_t *output, int32_t *argmax, int64_t outputD,
int64_t outputH, int64_t outputW, int64_t inputN, int64_t inputC, int64_t inputD, int64_t inputH, int64_t inputW,
int64_t kernelsizeD, int64_t kernelsizeH, int64_t kernelsizeW, const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_IMPL_CUH_
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T, typename S, typename G>
CUDA_LIB_EXPORT void CalFractionalmaxpool3dwithfixedksize(const T *input, const S *random_samples, T *output, G *argmax,
int64_t outputD, int64_t outputH, int64_t outputW,
int64_t inputN, int64_t inputC, int64_t inputD,
int64_t inputH, int64_t inputW, int64_t kernelsizeD,
int64_t kernelsizeH, int64_t kernelsizeW,
const int64_t outer_size, const uint32_t &device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_IMPL_CUH_

View File

@ -0,0 +1,169 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/nn/fractional_max_pool3d_grad_with_fixed_ksize_gpu_kernel.h"
#include <utility>
#include <iostream>
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kOriginInputIndex = 0;
constexpr size_t kOutBackpropIndex = 1;
constexpr size_t kArgmaxIndex = 2;
constexpr size_t kOutputIndex = 0;
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateFractionalMaxPool3DGradWithFixedKsizeKernelPtr(
const std::string &kernel_name, const uint32_t &device_id) {
return std::make_unique<cukernel::FractionalMaxPool3DGradWithFixedKsizeHelperGpuKernel<T, S>>(kernel_name, device_id);
}
using FractionalMaxPool3DGradWithFixedKsizePtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
const std::vector<std::pair<KernelAttr, FractionalMaxPool3DGradWithFixedKsizePtrCreatorFunc>> kernel_attr = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
CreateFractionalMaxPool3DGradWithFixedKsizeKernelPtr<half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
CreateFractionalMaxPool3DGradWithFixedKsizeKernelPtr<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
CreateFractionalMaxPool3DGradWithFixedKsizeKernelPtr<double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DGradWithFixedKsizeKernelPtr<int32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DGradWithFixedKsizeKernelPtr<int64_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
CreateFractionalMaxPool3DGradWithFixedKsizeKernelPtr<half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CreateFractionalMaxPool3DGradWithFixedKsizeKernelPtr<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
CreateFractionalMaxPool3DGradWithFixedKsizeKernelPtr<double, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DGradWithFixedKsizeKernelPtr<int32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DGradWithFixedKsizeKernelPtr<int64_t, int64_t>}};
} // namespace
bool FractionalMaxPool3DGradWithFixedKsizeGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs,
void *stream_ptr) {
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
return false;
}
return true;
}
bool FractionalMaxPool3DGradWithFixedKsizeGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::FractionalMaxPool3DGradWithFixedKsize>(base_operator);
kernel_name_ = kernel_ptr->name();
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
if (!is_match) {
return false;
}
attr_ptr_->data_format = kernel_ptr->get_data_format();
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
helper_ptr_->SetKernelParam(attr_ptr_);
Resize(base_operator, inputs, outputs);
return true;
}
int FractionalMaxPool3DGradWithFixedKsizeGpuKernelMod::Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<int64_t>> output_shapes;
std::vector<int64_t> origin_input_shape = inputs[kOriginInputIndex]->GetShapeVector();
std::vector<int64_t> out_backprop_shape = inputs[kOutBackpropIndex]->GetShapeVector();
std::vector<int64_t> argmax_shape = inputs[kArgmaxIndex]->GetShapeVector();
std::vector<int64_t> out_shape = outputs[kOutputIndex]->GetShapeVector();
input_shapes.emplace_back(origin_input_shape);
input_shapes.emplace_back(out_backprop_shape);
input_shapes.emplace_back(argmax_shape);
output_shapes.emplace_back(out_shape);
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
return KRET_RESIZE_FAILED;
}
input_size_list_ = helper_ptr_->GetInputSizeList();
output_size_list_ = helper_ptr_->GetOutputSizeList();
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
return KRET_OK;
}
std::vector<KernelAttr> FractionalMaxPool3DGradWithFixedKsizeGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, FractionalMaxPool3DGradWithFixedKsizePtrCreatorFunc> &item) { return item.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FractionalMaxPool3DGradWithFixedKsize,
FractionalMaxPool3DGradWithFixedKsizeGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_FRACTIONALMAXPOOL3DGRADWITHFIXEDKSIZE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_FRACTIONALMAXPOOL3DGRADWITHFIXEDKSIZE_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <map>
#include "mindspore/core/ops/grad/fractional_max_pool3d_grad_with_fixed_ksize.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/fractionalmaxpool3dgradwithfixedksize_helper.h"
namespace mindspore {
namespace kernel {
class FractionalMaxPool3DGradWithFixedKsizeGpuKernelMod : public NativeGpuKernelMod {
public:
FractionalMaxPool3DGradWithFixedKsizeGpuKernelMod() {
attr_ptr_ = std::make_shared<cukernel::FractionalMaxPool3DGradWithFixedKsizeAttr>();
}
~FractionalMaxPool3DGradWithFixedKsizeGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
std::shared_ptr<cukernel::FractionalMaxPool3DGradWithFixedKsizeAttr> attr_ptr_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_FRACTIONALMAXPOOL3DGRADWITHFIXEDKSIZE_GPU_KERNEL_H_

View File

@ -0,0 +1,291 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/nn/fractional_max_pool3d_with_fixed_ksize_gpu_kernel.h"
#include <utility>
#include <iostream>
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputIndex = 0;
constexpr size_t kRandomSamplesIndex = 1;
constexpr size_t kOutputIndex = 0;
constexpr size_t kArgmaxIndex = 1;
template <typename T, typename S, typename G>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateFractionalMaxPool3DWithFixedKsizeKernelPtr(
const std::string &kernel_name, const uint32_t &device_id) {
return std::make_unique<cukernel::FractionalMaxPool3DWithFixedKsizeHelperGpuKernel<T, S, G>>(kernel_name, device_id);
}
using FractionalMaxPool3DWithFixedKsizePtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
const std::vector<std::pair<KernelAttr, FractionalMaxPool3DWithFixedKsizePtrCreatorFunc>> kernel_attr = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<half, half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<half, float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<half, double, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<half, half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<half, float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<half, double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<float, half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<float, float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<float, double, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<float, half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<float, float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<float, double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<double, half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<double, float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<double, double, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<double, half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<double, float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<double, double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int32_t, half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int32_t, float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int32_t, double, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int32_t, half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int32_t, float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int32_t, double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int64_t, half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int64_t, float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int64_t, double, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int64_t, half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int64_t, float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
CreateFractionalMaxPool3DWithFixedKsizeKernelPtr<int64_t, double, int32_t>}};
} // namespace
bool FractionalMaxPool3DWithFixedKsizeGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
return false;
}
return true;
}
bool FractionalMaxPool3DWithFixedKsizeGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::FractionalMaxPool3DWithFixedKsize>(base_operator);
kernel_name_ = kernel_ptr->name();
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
if (!is_match) {
return false;
}
attr_ptr_->ksize = kernel_ptr->get_ksize();
attr_ptr_->output_shape = kernel_ptr->get_output_shape();
attr_ptr_->data_format = kernel_ptr->get_data_format();
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
helper_ptr_->SetKernelParam(attr_ptr_);
Resize(base_operator, inputs, outputs);
return true;
}
int FractionalMaxPool3DWithFixedKsizeGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<int64_t>> output_shapes;
std::vector<int64_t> inp_shape = inputs[kInputIndex]->GetShapeVector();
std::vector<int64_t> random_samples_shape = inputs[kRandomSamplesIndex]->GetShapeVector();
std::vector<int64_t> out_shape = outputs[kOutputIndex]->GetShapeVector();
std::vector<int64_t> argmax_shape = outputs[kArgmaxIndex]->GetShapeVector();
input_shapes.emplace_back(inp_shape);
input_shapes.emplace_back(random_samples_shape);
output_shapes.emplace_back(out_shape);
output_shapes.emplace_back(argmax_shape);
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
return KRET_RESIZE_FAILED;
}
input_size_list_ = helper_ptr_->GetInputSizeList();
output_size_list_ = helper_ptr_->GetOutputSizeList();
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
return KRET_OK;
}
std::vector<KernelAttr> FractionalMaxPool3DWithFixedKsizeGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, FractionalMaxPool3DWithFixedKsizePtrCreatorFunc> &item) { return item.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FractionalMaxPool3DWithFixedKsize,
FractionalMaxPool3DWithFixedKsizeGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <map>
#include "mindspore/core/ops/fractional_max_pool3d_with_fixed_ksize.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/fractionalmaxpool3dwithfixedksize_helper.h"
namespace mindspore {
namespace kernel {
class FractionalMaxPool3DWithFixedKsizeGpuKernelMod : public NativeGpuKernelMod {
public:
FractionalMaxPool3DWithFixedKsizeGpuKernelMod() {
attr_ptr_ = std::make_shared<cukernel::FractionalMaxPool3DWithFixedKsizeAttr>();
}
~FractionalMaxPool3DWithFixedKsizeGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
std::shared_ptr<cukernel::FractionalMaxPool3DWithFixedKsizeAttr> attr_ptr_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_FRACTIONALMAXPOOL3DWITHFIXEDKSIZE_GPU_KERNEL_H_

View File

@ -135,9 +135,7 @@ abstract::TupleShapePtr FractionalMaxPool3DWithFixedKsizeInferShape(const Primit
output_size.push_back(c_dim);
}
}
if (std::any_of(output_size.begin(), output_size.end(), [](int64_t shp_v) { return shp_v <= 0; })) {
MS_LOG(EXCEPTION) << "For '" << op_name << "', output_size is not valid.";
}
if (input_shape.size() == kDimSize4) {
if (random_samples_shape[0] != input_shape[0]) {
MS_EXCEPTION(ValueError)
@ -150,11 +148,6 @@ abstract::TupleShapePtr FractionalMaxPool3DWithFixedKsizeInferShape(const Primit
<< "', if 'x' is 4 dimensional, the second dimension size of 'random_samples' must be equal to 3.";
}
} else {
if (random_samples_shape[0] != input_shape[0]) {
MS_EXCEPTION(ValueError)
<< "For '" << op_name
<< "', if 'x' is 5 dimensional, the first dimension size of 'x' and 'random_samples' must be equal.";
}
if (random_samples_shape[1] != input_shape[1]) {
MS_EXCEPTION(ValueError)
<< "For '" << op_name
@ -190,7 +183,7 @@ TuplePtr FractionalMaxPool3DWithFixedKsizeInferType(const PrimitivePtr &primitiv
}
} // namespace
MIND_API_BASE_IMPL(FractionalMaxPool3DWithFixedKsize, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(FractionalMaxPool3DWithFixedKsize, BaseOperator);
AbstractBasePtr FractionalMaxPool3DWithFixedKsizeInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
@ -200,6 +193,39 @@ AbstractBasePtr FractionalMaxPool3DWithFixedKsizeInfer(const abstract::AnalysisE
return abstract::MakeAbstract(infer_shape, infer_type);
}
void FractionalMaxPool3DWithFixedKsize::Init(const std::vector<float> ksize, const std::vector<int64_t> output_shape,
const std::string data_format) {
set_ksize(ksize);
set_output_shape(output_shape);
set_data_format(data_format);
}
void FractionalMaxPool3DWithFixedKsize::set_ksize(const std::vector<float> ksize) {
(void)this->AddAttr("ksize", api::MakeValue(ksize));
}
void FractionalMaxPool3DWithFixedKsize::set_output_shape(const std::vector<int64_t> output_shape) {
(void)this->AddAttr("output_shape", api::MakeValue(output_shape));
}
void FractionalMaxPool3DWithFixedKsize::set_data_format(const std::string data_format) {
(void)this->AddAttr(kFormat, api::MakeValue(data_format));
}
std::vector<float> FractionalMaxPool3DWithFixedKsize::get_ksize() const {
auto value_ptr = GetAttr("ksize");
return GetValue<std::vector<float>>(value_ptr);
}
std::vector<int64_t> FractionalMaxPool3DWithFixedKsize::get_output_shape() const {
auto value_ptr = GetAttr("output_shape");
return GetValue<std::vector<int64_t>>(value_ptr);
}
std::string FractionalMaxPool3DWithFixedKsize::get_data_format() const {
return GetValue<std::string>(GetAttr(kFormat));
}
REGISTER_PRIMITIVE_EVAL_IMPL(FractionalMaxPool3DWithFixedKsize, prim::kPrimFractionalMaxPool3DWithFixedKsize,
FractionalMaxPool3DWithFixedKsizeInfer, nullptr, true);
} // namespace ops

View File

@ -33,6 +33,27 @@ class MIND_API FractionalMaxPool3DWithFixedKsize : public BaseOperator {
FractionalMaxPool3DWithFixedKsize() : BaseOperator(kNameFractionalMaxPool3DWithFixedKsize) {
InitIOName({"x", "random_samples"}, {"y", "argmax"});
}
void Init(const std::vector<float> ksize, const std::vector<int64_t> output_shape, const std::string data_format);
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.FractionalMaxPool3DWithFixedKsize for the
/// inputs.
void set_ksize(const std::vector<float> ksize);
/// \brief Set ksize.
void set_output_shape(const std::vector<int64_t> output_shape);
/// \brief Set output shape.
void set_data_format(const std::string data_format);
/// \brief Set data format.
std::vector<float> get_ksize() const;
/// \brief Method to get ksize attributes.
///
/// \return ksize attributes.
std::vector<int64_t> get_output_shape() const;
/// \brief Method to get output shape attributes.
///
/// \return output shape attributes.
std::string get_data_format() const;
/// \brief Method to get data format attributes.
///
/// \return data format attributes.
};
abstract::AbstractBasePtr FractionalMaxPool3DWithFixedKsizeInfer(

View File

@ -134,9 +134,6 @@ abstract::ShapePtr FractionalMaxPool3DGradWithFixedKsizeInferShape(const Primiti
output_size.push_back(c_dim_);
}
}
if (std::any_of(output_size.begin(), output_size.end(), [](int64_t shp_v) { return shp_v <= 0; })) {
MS_LOG(EXCEPTION) << "For '" << op_name << "', output_size is not valid.";
}
return std::make_shared<abstract::Shape>(output_size);
}
@ -162,7 +159,7 @@ TypePtr FractionalMaxPool3DGradWithFixedKsizeInferType(const PrimitivePtr &primi
}
} // namespace
MIND_API_BASE_IMPL(FractionalMaxPool3DGradWithFixedKsize, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(FractionalMaxPool3DGradWithFixedKsize, BaseOperator);
AbstractBasePtr FractionalMaxPool3DGradWithFixedKsizeInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
@ -172,6 +169,16 @@ AbstractBasePtr FractionalMaxPool3DGradWithFixedKsizeInfer(const abstract::Analy
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape);
}
void FractionalMaxPool3DGradWithFixedKsize::Init(const std::string data_format) { set_data_format(data_format); }
void FractionalMaxPool3DGradWithFixedKsize::set_data_format(const std::string data_format) {
(void)this->AddAttr(kFormat, api::MakeValue(data_format));
}
std::string FractionalMaxPool3DGradWithFixedKsize::get_data_format() const {
return GetValue<std::string>(GetAttr(kFormat));
}
REGISTER_PRIMITIVE_EVAL_IMPL(FractionalMaxPool3DGradWithFixedKsize, prim::kPrimFractionalMaxPool3DGradWithFixedKsize,
FractionalMaxPool3DGradWithFixedKsizeInfer, nullptr, true);
} // namespace ops

View File

@ -33,6 +33,15 @@ class MIND_API FractionalMaxPool3DGradWithFixedKsize : public BaseOperator {
FractionalMaxPool3DGradWithFixedKsize() : BaseOperator(kNameFractionalMaxPool3DGradWithFixedKsize) {
InitIOName({"origin_input", "out_backprop", "argmax"}, {"y"});
}
void Init(const std::string data_format);
/// \brief Init. Refer to the parameters of Python API @ref
/// mindspore.ops.operations._grad_ops.FractionalMaxPool3DWithFixedKsize for the inputs.
void set_data_format(const std::string data_format);
/// \brief Set data format.
std::string get_data_format() const;
/// \brief Method to get data format attributes.
///
/// \return data format attributes.
};
abstract::AbstractBasePtr FractionalMaxPool3DGradWithFixedKsizeInfer(

View File

@ -9755,7 +9755,7 @@ class FractionalMaxPool3DWithFixedKsize(Primitive):
ValueError: If the third dimension size of `random_samples` is not 3.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])

View File

@ -0,0 +1,81 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations.nn_ops as ops
import mindspore.ops.operations._grad_ops as grad_ops
class NetFractionalMaxPool3DWithFixedKsize(nn.Cell):
def __init__(self, ksize, output_shape):
super(NetFractionalMaxPool3DWithFixedKsize, self).__init__()
self.fractional_max_pool_3d_with_fixed_ksize = ops.FractionalMaxPool3DWithFixedKsize(ksize, output_shape)
def construct(self, x, random_sapmples):
return self.fractional_max_pool_3d_with_fixed_ksize(x, random_sapmples)
class NetFractionalMaxPool3DGradWithFixedKsize(nn.Cell):
def __init__(self):
super(NetFractionalMaxPool3DGradWithFixedKsize, self).__init__()
self.fractional_max_pool_3d_grad_with_fixed_ksize = grad_ops.FractionalMaxPool3DGradWithFixedKsize()
def construct(self, origin_input, out_backprop, argmax):
return self.fractional_max_pool_3d_grad_with_fixed_ksize(origin_input, out_backprop, argmax)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fractionalmaxpool3dwithfixedksize():
"""
Feature: FractionalMaxPool3DWithFixedKsize
Description: Test of input
Expectation: The results are as expected
"""
context_mode_types = [context.GRAPH_MODE, context.PYNATIVE_MODE]
types_input1 = [np.float16, np.float32, np.int32, np.int64]
types_input2 = [np.float16, np.float32]
for context_mode_type in context_mode_types:
context.set_context(mode=context_mode_type, device_target='GPU')
for type_input1 in types_input1:
for type_input2 in types_input2:
x_np = np.array([i+1 for i in range(64)]).reshape([1, 1, 4, 4, 4]).astype(type_input1)
x_ms = Tensor(x_np)
random_samples = Tensor(np.array([0.5, 0.5, 0.8]).reshape([1, 1, 3]).astype(type_input2))
ksize = (1.0, 1.0, 1.0)
output_shape = (2, 2, 3)
net = NetFractionalMaxPool3DWithFixedKsize(ksize, output_shape)
output_ms, argmax = net(x_ms, random_samples)
expect_output = np.array([[[[[1, 2, 4], [13, 14, 16]],
[[49, 50, 52], [61, 62, 64]]]]]).astype(type_input1)
expect_output_argmax = np.array([[[[[0, 1, 3], [12, 13, 15]],
[[48, 49, 51], [60, 61, 63]]]]]).astype(type_input2)
assert np.allclose(output_ms.asnumpy(), expect_output)
assert np.allclose(argmax.asnumpy(), expect_output_argmax)
out_backprop = Tensor(np.array([i+1 for i in range(12)]).reshape([1, 1, 2, 2, 3]).astype(type_input1))
net_grad = NetFractionalMaxPool3DGradWithFixedKsize()
output_grad = net_grad(x_ms, out_backprop, argmax)
expect_output_grad = np.array([[[[[1, 2, 0, 3], [0, 0, 0, 0], [0, 0, 0, 0], [4, 5, 0, 6]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[7, 8, 0, 9], [0, 0, 0, 0], [0, 0, 0, 0],
[10, 11, 0, 12]]]]]).astype(type_input2)
assert np.allclose(output_grad.asnumpy(), expect_output_grad)