!48119 add circular mode for PadV3/PadV3Grad on CPU/GPU.

Merge pull request !48119 from yangshuo/br_padv3
i-robot 2023-01-31 07:25:49 +00:00 committed by Gitee
commit ed6dcb4598
18 changed files with 753 additions and 69 deletions

View File

@ -6,7 +6,7 @@ mindspore.ops.PadV3
Pads the input according to the parameters `mode` and `paddings_contiguous`.
Parameters:
- **mode** (str, optional) - Padding mode; supports "constant", "reflect" and "edge". Default: "constant".
- **mode** (str, optional) - Padding mode; supports "constant", "reflect", "edge" and "circular". Default: "constant".
- **paddings_contiguous** (bool, optional) - Whether the paddings are contiguous. If True, `paddings` is in the format [begin0, end0, begin1, end1, ...]; if False, `paddings` is in the format [begin0, begin1, ..., end0, end1, ...]. Default: True.
Inputs:
@ -23,11 +23,12 @@ mindspore.ops.PadV3
- **ValueError** - `mode` is not of type str or is not in the supported list.
- **ValueError** - `mode` is "constant" and the number of elements of `paddings` is not even.
- **ValueError** - `mode` is "constant" and the number of elements of `paddings` is larger than the input dimension times 2.
- **ValueError** - `mode` is "edge" or "reflect" and the number of elements of `paddings` is not 2, 4 or 6.
- **ValueError** - `mode` is "edge" or "reflect", `x` is 3-dimensional, and the number of elements of `paddings` is not 2.
- **ValueError** - `mode` is "edge" or "reflect", `x` is 4-dimensional, and the number of elements of `paddings` is not 4.
- **ValueError** - `mode` is "edge" or "reflect" and the dimension of `x` is less than 3.
- **ValueError** - `mode` is "edge" and the dimension of `x` is greater than 5.
- **ValueError** - `mode` is "edge", "reflect" or "circular" and the number of elements of `paddings` is not 2, 4 or 6.
- **ValueError** - `mode` is "edge", "reflect" or "circular", `x` is 3-dimensional, and the number of elements of `paddings` is not 2.
- **ValueError** - `mode` is "edge", "reflect" or "circular", `x` is 4-dimensional, and the number of elements of `paddings` is not 4.
- **ValueError** - `mode` is "circular", `x` is 5-dimensional, and the number of elements of `paddings` is not 6.
- **ValueError** - `mode` is "edge", "reflect" or "circular" and the dimension of `x` is less than 3.
- **ValueError** - `mode` is "edge" or "circular" and the dimension of `x` is greater than 5.
- **ValueError** - `mode` is "reflect" and the dimension of `x` is greater than 4.
- **ValueError** - `mode` is "reflect" and a padding value is greater than the corresponding dimension of `x`.
- **ValueError** - Any dimension of the output shape is not greater than zero after padding.
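For orientation, a minimal usage sketch of the new mode (assumptions: the two-input call form for non-constant modes and the public `ops.PadV3` path, both mirrored from the dynamic-shape tests added later in this PR):

import numpy as np
import mindspore
from mindspore import Tensor, ops

# Circular padding wraps values from the opposite edge instead of
# reflecting ("reflect") or clamping ("edge").
pad = ops.PadV3(mode="circular")
x = Tensor(np.arange(9).reshape(1, 3, 3).astype(np.float32))
paddings = Tensor((1, 2), dtype=mindspore.int64)  # (pad_left, pad_right)
out = pad(x, paddings)  # width 3 -> 1 + 3 + 2 = 6
# Each row becomes [w2, w0, w1, w2, w0, w1], e.g. row 0 -> [2, 0, 1, 2, 0, 1].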

View File

@ -16,6 +16,7 @@
#include "plugin/device/cpu/kernel/pad_v3_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/op_name.h"
#include "mindspore/core/ops/pad_v3.h"
namespace mindspore {
@ -25,16 +26,13 @@ constexpr auto kPadV3 = "PadV3";
constexpr const size_t kConstantInputsNum = 3;
constexpr const size_t kOtherInputsNum = 2;
constexpr const size_t kOutputsNum = 1;
constexpr int64_t kInput3D = 3;
constexpr int64_t kInput4D = 4;
constexpr int64_t kInput5D = 5;
constexpr int64_t kPadding1D = 2;
constexpr int64_t kPadding2D = 4;
constexpr int64_t kPadding3D = 6;
constexpr int64_t kNum2 = 2;
constexpr int64_t kNum3 = 3;
constexpr int64_t kNum4 = 4;
const std::vector<std::string> mode_list = {"constant", "reflect", "edge"};
const std::vector<std::string> mode_list = {ops::kConstant, ops::kReflect, ops::kEdge, ops::kCircular};
} // namespace
bool PadV3CpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@ -208,11 +206,12 @@ void PadV3CpuKernelMod::OtherModeCompute1D(T *input_ptr, T *output_ptr, int64_t
int64_t nplane = 0;
int64_t input_w = input_shape_[kNum2];
int64_t output_w = output_shape_.end()[-1];
int64_t pad_l = paddings_[0];
int64_t pad_l = paddings_[kIndex0];
int64_t pad_r = paddings_[kIndex1];
int64_t i_start_x = std::max(int64_t(0), -pad_l);
int64_t o_start_x = std::max(int64_t(0), pad_l);
for (int64_t j = 0; j < output_w; ++j) {
auto ip_x = IndexCalculate(pad_l, j, input_w, o_start_x, i_start_x);
auto ip_x = IndexCalculate(pad_l, pad_r, j, input_w, o_start_x, i_start_x);
T *dest_p = output_ptr + p * output_w * (nplane + 1) + j;
T *src_p = input_ptr + p * input_w * (nplane + 1) + ip_x;
*dest_p = *src_p;
@ -221,21 +220,23 @@ void PadV3CpuKernelMod::OtherModeCompute1D(T *input_ptr, T *output_ptr, int64_t
template <typename T>
void PadV3CpuKernelMod::OtherModeCompute2D(T *input_ptr, T *output_ptr, int64_t p) const {
int64_t pad_l = paddings_[0];
int64_t pad_t = paddings_[kNum2];
int64_t pad_l = paddings_[kIndex0];
int64_t pad_r = paddings_[kIndex1];
int64_t pad_t = paddings_[kIndex2];
int64_t pad_d = paddings_[kIndex3];
int64_t nplane = 0;
int64_t input_h = input_shape_[kNum2];
int64_t input_w = input_shape_[kNum3];
int64_t output_h = input_h + pad_t + paddings_[kNum3];
int64_t output_w = input_w + pad_l + paddings_[1];
int64_t input_h = input_shape_[kIndex2];
int64_t input_w = input_shape_[kIndex3];
int64_t output_h = input_h + pad_t + paddings_[kIndex3];
int64_t output_w = input_w + pad_l + paddings_[kIndex1];
int64_t i_start_x = std::max(int64_t(0), -pad_l);
int64_t i_start_y = std::max(int64_t(0), -pad_t);
int64_t o_start_x = std::max(int64_t(0), pad_l);
int64_t o_start_y = std::max(int64_t(0), pad_t);
for (int64_t i = 0; i < output_h; ++i) {
for (int64_t j = 0; j < output_w; ++j) {
auto ip_x = IndexCalculate(pad_l, j, input_w, o_start_x, i_start_x);
auto ip_y = IndexCalculate(pad_t, i, input_h, o_start_y, i_start_y);
auto ip_x = IndexCalculate(pad_l, pad_r, j, input_w, o_start_x, i_start_x);
auto ip_y = IndexCalculate(pad_t, pad_d, i, input_h, o_start_y, i_start_y);
T *dest_p = output_ptr + p * output_w * output_h * (nplane + 1) + i * output_w + j;
T *src_p = input_ptr + p * input_w * input_h * (nplane + 1) + ip_y * input_w + ip_x;
*dest_p = *src_p;
@ -245,9 +246,12 @@ void PadV3CpuKernelMod::OtherModeCompute2D(T *input_ptr, T *output_ptr, int64_t
template <typename T>
void PadV3CpuKernelMod::OtherModeCompute3D(T *input_ptr, T *output_ptr, int64_t p) const {
int64_t pad_l = paddings_[0];
int64_t pad_t = paddings_[kNum2];
int64_t pad_f = paddings_[kNum4];
int64_t pad_l = paddings_[kIndex0];
int64_t pad_r = paddings_[kIndex1];
int64_t pad_t = paddings_[kIndex2];
int64_t pad_d = paddings_[kIndex3];
int64_t pad_f = paddings_[kIndex4];
int64_t pad_b = paddings_[kIndex5];
int64_t nplane = 0;
int64_t input_d = input_shape_[kNum2];
int64_t input_h = input_shape_[kNum3];
@ -264,9 +268,9 @@ void PadV3CpuKernelMod::OtherModeCompute3D(T *input_ptr, T *output_ptr, int64_t
for (int64_t k = 0; k < output_d; ++k) {
for (int64_t j = 0; j < output_h; ++j) {
for (int64_t i = 0; i < output_w; ++i) {
auto ip_x = IndexCalculate(pad_l, i, input_w, o_start_x, i_start_x);
auto ip_y = IndexCalculate(pad_t, j, input_h, o_start_y, i_start_y);
auto ip_z = IndexCalculate(pad_f, k, input_d, o_start_z, i_start_z);
auto ip_x = IndexCalculate(pad_l, pad_r, i, input_w, o_start_x, i_start_x);
auto ip_y = IndexCalculate(pad_t, pad_d, j, input_h, o_start_y, i_start_y);
auto ip_z = IndexCalculate(pad_f, pad_b, k, input_d, o_start_z, i_start_z);
T *dest_p =
output_ptr + p * output_w * output_h * output_d * (nplane + 1) + k * output_w * output_h + j * output_w + i;
T *src_p =
@ -277,22 +281,26 @@ void PadV3CpuKernelMod::OtherModeCompute3D(T *input_ptr, T *output_ptr, int64_t
}
}
int64_t PadV3CpuKernelMod::IndexCalculate(int64_t pad_value, int64_t now, int64_t input_value, int64_t o_start,
int64_t i_start) const {
int64_t PadV3CpuKernelMod::IndexCalculate(int64_t pad_value, int64_t pad_end, int64_t now, int64_t input_value,
int64_t o_start, int64_t i_start) const {
int64_t ip = 0;
if (now < pad_value) {
if (mode_ == "reflect") {
if (mode_ == ops::kReflect) {
ip = pad_value + pad_value - now;
} else if (mode_ == "edge") {
} else if (mode_ == ops::kEdge) {
ip = pad_value;
} else if (mode_ == ops::kCircular) {
ip = input_value + now + std::min(int64_t(0), pad_end);
}
} else if (now >= pad_value && now < input_value + pad_value) {
ip = now;
} else {
if (mode_ == "reflect") {
if (mode_ == ops::kReflect) {
ip = (input_value + pad_value - 1) + (input_value + pad_value - 1) - now;
} else if (mode_ == "edge") {
} else if (mode_ == ops::kEdge) {
ip = input_value + pad_value - 1;
} else if (mode_ == ops::kCircular) {
ip = now - input_value - std::min(int64_t(0), pad_value);
}
}
ip = ip - o_start + i_start;
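The circular branch above maps each output coordinate to a wrapped input coordinate, with the `std::min(int64_t(0), pad)` terms compensating for negative (cropping) paddings. A standalone Python re-derivation of the same mapping (function name illustrative), checked against the circular test expectation added later in this PR:

def circular_index(pad_begin, pad_end, now, size):
    # Left pad region wraps to the tail; right pad region wraps to the head.
    if now < pad_begin:
        ip = size + now + min(0, pad_end)
    elif now < size + pad_begin:
        ip = now
    else:
        ip = now - size - min(0, pad_begin)
    # Same final shift as ip - o_start + i_start in the kernel.
    return ip - max(0, pad_begin) + max(0, -pad_begin)

# Width 3 with paddings (1, 2): output column j reads these input columns,
# matching the expected row [2, 0, 1, 2, 0, 1] in the 3-D dynamic-shape test.
print([circular_index(1, 2, j, 3) for j in range(6)])  # [2, 0, 1, 2, 0, 1]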
@ -307,7 +315,7 @@ bool PadV3CpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, cons
}
auto input_ptr = static_cast<T *>(inputs[0]->addr);
auto output_ptr = static_cast<T *>(outputs[0]->addr);
if (mode_ == "constant") {
if (mode_ == ops::kConstant) {
T constant_values = *(static_cast<T *>(inputs[2]->addr));
for (int64_t i = 0; i < input_dim_ / kNum2; ++i) {
int64_t u = paddings_[i * kNum2];

View File

@ -77,7 +77,8 @@ class PadV3CpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<Pa
template <typename T>
void OtherModeCompute3D(T *input_ptr, T *output_ptr, int64_t p) const;
int64_t IndexCalculate(int64_t pad_value, int64_t now, int64_t input_value, int64_t o_start, int64_t i_start) const;
int64_t IndexCalculate(int64_t pad_value, int64_t pad_end, int64_t now, int64_t input_value, int64_t o_start,
int64_t i_start) const;
bool paddings_contiguous_;
int64_t parallelSliceNum_{1};

View File

@ -16,6 +16,7 @@
#include "plugin/device/cpu/kernel/pad_v3_grad_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/op_name.h"
#include "mindspore/core/ops/grad/pad_v3_grad.h"
namespace mindspore {
@ -30,6 +31,9 @@ constexpr int64_t k1DNum = 2;
constexpr int64_t kpad_l = 0;
constexpr int64_t kpad_t = 2;
constexpr int64_t kpad_f = 4;
constexpr int64_t kpad_r = 1;
constexpr int64_t kpad_d = 3;
constexpr int64_t kpad_b = 5;
constexpr int64_t kwidth = 1;
constexpr int64_t kheight = 2;
constexpr int64_t kchannel = 3;
@ -40,7 +44,7 @@ constexpr int64_t k2Num = 2;
constexpr int64_t padding_pos_2 = 2;
constexpr int64_t padding_pos_3 = 3;
constexpr int64_t padding_pos_4 = 4;
const std::vector<std::string> mode_list = {"reflect", "edge"};
const std::vector<std::string> mode_list = {ops::kReflect, ops::kEdge, ops::kCircular};
} // namespace
bool PadV3GradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@ -55,7 +59,7 @@ bool PadV3GradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
mode_ = kernel_ptr->get_mode();
const bool is_mode_available = std::find(mode_list.begin(), mode_list.end(), mode_) != mode_list.end();
if (is_mode_available == false) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'mode' should be 'constant', 'reflect' or 'edge', but got "
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the 'mode' should be 'reflect', 'edge' or 'circular', but got "
<< mode_;
return false;
}
@ -148,7 +152,7 @@ void PadV3GradCpuKernelMod::PadV3GradCompute(T *input, T *output, int64_t p) con
template <typename T>
void PadV3GradCpuKernelMod::PadV3GradCompute1D(T *input, T *output, int64_t p) const {
for (int j = 0; j < input_w_; j++) {
auto ip_x = IndexCalculate(pad_l_, j, output_w_, o_start_x_, i_start_x_);
auto ip_x = IndexCalculate(pad_l_, pad_r_, j, output_w_, o_start_x_, i_start_x_);
T *src_p = input + p * input_w_ + j;
T *dest_p = output + p * output_w_ + ip_x;
*dest_p += *src_p;
@ -158,8 +162,8 @@ void PadV3GradCpuKernelMod::PadV3GradCompute1D(T *input, T *output, int64_t p) c
template <typename T>
void PadV3GradCpuKernelMod::PadV3GradCompute2D(T *input, T *output, int64_t p, int64_t i) const {
for (int j = 0; j < input_w_; j++) {
auto ip_x = IndexCalculate(pad_l_, j, output_w_, o_start_x_, i_start_x_);
auto ip_y = IndexCalculate(pad_t_, i, output_h_, o_start_y_, i_start_y_);
auto ip_x = IndexCalculate(pad_l_, pad_r_, j, output_w_, o_start_x_, i_start_x_);
auto ip_y = IndexCalculate(pad_t_, pad_d_, i, output_h_, o_start_y_, i_start_y_);
T *src_p = input + p * input_w_ * input_h_ + i * input_w_ + j;
T *dest_p = output + p * output_w_ * output_h_ + ip_y * output_w_ + ip_x;
*dest_p += *src_p;
@ -170,9 +174,9 @@ template <typename T>
void PadV3GradCpuKernelMod::PadV3GradCompute3D(T *input, T *output, int64_t p, int64_t z) const {
for (int i = 0; i < input_h_; i++) {
for (int j = 0; j < input_w_; j++) {
auto ip_x = IndexCalculate(pad_l_, j, output_w_, o_start_x_, i_start_x_);
auto ip_y = IndexCalculate(pad_t_, i, output_h_, o_start_y_, i_start_y_);
auto ip_z = IndexCalculate(pad_f_, z, output_c_, o_start_z_, i_start_z_);
auto ip_x = IndexCalculate(pad_l_, pad_r_, j, output_w_, o_start_x_, i_start_x_);
auto ip_y = IndexCalculate(pad_t_, pad_d_, i, output_h_, o_start_y_, i_start_y_);
auto ip_z = IndexCalculate(pad_f_, pad_b_, z, output_c_, o_start_z_, i_start_z_);
T *src_p = input + p * input_w_ * input_h_ * input_c_ + z * input_w_ * input_h_ + i * input_w_ + j;
T *dest_p =
output + p * output_w_ * output_h_ * output_c_ + ip_z * output_w_ * output_h_ + ip_y * output_w_ + ip_x;
@ -181,22 +185,26 @@ void PadV3GradCpuKernelMod::PadV3GradCompute3D(T *input, T *output, int64_t p, i
}
}
int64_t PadV3GradCpuKernelMod::IndexCalculate(int64_t pad_value, int64_t now, int64_t output_value, int64_t o_start,
int64_t i_start) const {
int64_t PadV3GradCpuKernelMod::IndexCalculate(int64_t pad_value, int64_t pad_end, int64_t now, int64_t output_value,
int64_t o_start, int64_t i_start) const {
int64_t ip = 0;
if (now < pad_value) {
if (mode_ == "reflect") {
if (mode_ == ops::kReflect) {
ip = pad_value + pad_value - now;
} else if (mode_ == "edge") {
} else if (mode_ == ops::kEdge) {
ip = pad_value;
} else if (mode_ == ops::kCircular) {
ip = output_value + now + std::min(int64_t(0), pad_end);
}
} else if (now >= pad_value && now < output_value + pad_value) {
ip = now;
} else {
if (mode_ == "reflect") {
if (mode_ == ops::kReflect) {
ip = (output_value + pad_value - 1) + (output_value + pad_value - 1) - now;
} else if (mode_ == "edge") {
} else if (mode_ == ops::kEdge) {
ip = output_value + pad_value - 1;
} else if (mode_ == ops::kCircular) {
ip = now - output_value - std::min(int64_t(0), pad_value);
}
}
ip = ip - o_start + i_start;
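The grad kernel reuses the same wrapped mapping but scatter-adds (`*dest_p += *src_p`): every element of the incoming gradient accumulates into the input cell it was copied from in the forward pass. A 1-D Python sketch under those assumptions, reproducing the test_padv3grad_circular_3d expectation added later in this PR:

import numpy as np

def circular_pad_grad_1d(dy, size, pad_begin, pad_end):
    # dx[ip] += dy[now], with ip computed exactly as in the forward mapping.
    dx = np.zeros(size, dy.dtype)
    for now, g in enumerate(dy):
        if now < pad_begin:
            ip = size + now + min(0, pad_end)
        elif now < size + pad_begin:
            ip = now
        else:
            ip = now - size - min(0, pad_begin)
        dx[ip - max(0, pad_begin) + max(0, -pad_begin)] += g
    return dx

# Rows [1, 2, 3, 4] and [5, 6, 7, 8] with size 2 and pads (1, 1) give the
# expected gradient [[6, 4], [14, 12]].
print(circular_pad_grad_1d(np.array([1., 2., 3., 4.]), 2, 1, 1))  # [6. 4.]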
@ -209,7 +217,6 @@ bool PadV3GradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
if (!GetPaddings<S>(inputs)) {
MS_LOG(EXCEPTION) << "get paddings failed";
}
output_w_ = output_shape_.end()[-kwidth];
output_h_ = output_shape_.end()[-kheight];
output_c_ = output_shape_.end()[-kchannel];
@ -227,6 +234,9 @@ bool PadV3GradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
pad_l_ = paddings_[kpad_l];
pad_t_ = paddings_[kpad_t];
pad_f_ = paddings_[kpad_f];
pad_r_ = paddings_[kpad_r];
pad_d_ = paddings_[kpad_d];
pad_b_ = paddings_[kpad_b];
int64_t output_num_ = 1;
for (int64_t i = 0; i < input_dim_; i++) {

View File

@ -73,7 +73,8 @@ class PadV3GradCpuKernelMod : public NativeCpuKernelMod {
template <typename T>
void PadV3GradCompute3D(T *input, T *output, int64_t p, int64_t z) const;
int64_t IndexCalculate(int64_t pad_value, int64_t now, int64_t output_value, int64_t o_start, int64_t i_start) const;
int64_t IndexCalculate(int64_t pad_value, int64_t pad_end, int64_t now, int64_t output_value, int64_t o_start,
int64_t i_start) const;
using SelectFunc =
std::function<bool(PadV3GradCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
@ -103,6 +104,9 @@ class PadV3GradCpuKernelMod : public NativeCpuKernelMod {
int64_t pad_l_{0};
int64_t pad_t_{0};
int64_t pad_f_{0};
int64_t pad_r_{0};
int64_t pad_d_{0};
int64_t pad_b_{0};
int64_t parallelSliceNum_{1};
int64_t paddings_num_{0};
int64_t input_dim_{0};

View File

@ -104,7 +104,6 @@ class PadV3GradHelperGpuKernel : public GpuKernelHelperBase {
if (flag != 0) {
return flag;
}
// call cuda kernel
if (mode_ == ops::kReflect) {
CalReflectPadGrad3d(input_size_, input_ptr, input_shape_5d_[0], input_shape_5d_[1], input_shape_5d_[2],
@ -116,6 +115,13 @@ class PadV3GradHelperGpuKernel : public GpuKernelHelperBase {
input_shape_5d_[3], input_shape_5d_[4], output_shape_5d_[2], output_shape_5d_[3],
output_shape_5d_[4], paddings_3d_[0].first, paddings_3d_[1].first, paddings_3d_[2].first,
output_ptr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
} else if (mode_ == ops::kCircular) {
CalCircularPadGrad3d(input_size_, input_ptr, input_shape_5d_[kIndex2], input_shape_5d_[kIndex3],
input_shape_5d_[kIndex4], output_shape_5d_[kIndex2], output_shape_5d_[kIndex3],
output_shape_5d_[kIndex4], paddings_3d_[kIndex0].first, paddings_3d_[kIndex1].first,
paddings_3d_[kIndex2].first, paddings_3d_[kIndex0].second, paddings_3d_[kIndex1].second,
paddings_3d_[kIndex2].second, output_ptr, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
}
return 0;
}

View File

@ -132,6 +132,13 @@ class PadV3HelperGpuKernel : public GpuKernelHelperBase {
input_shape_5d_[3], input_shape_5d_[4], output_shape_5d_[2], output_shape_5d_[3],
output_shape_5d_[4], paddings_3d_[0].first, paddings_3d_[1].first, paddings_3d_[2].first, output_ptr,
device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
} else if (mode_ == ops::kCircular) {
CalCircularPad3d(output_size_, input_ptr, input_shape_5d_[kIndex2], input_shape_5d_[kIndex3],
input_shape_5d_[kIndex4], output_shape_5d_[kIndex2], output_shape_5d_[kIndex3],
output_shape_5d_[kIndex4], paddings_3d_[kIndex0].first, paddings_3d_[kIndex1].first,
paddings_3d_[kIndex2].first, paddings_3d_[kIndex0].second, paddings_3d_[kIndex1].second,
paddings_3d_[kIndex2].second, output_ptr, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
}
return 0;
}

View File

@ -68,6 +68,87 @@ __global__ void ConstantPadGrad3d(const size_t size, const T *dy, const int64_t
}
}
template <typename T>
__global__ void CircularPad3d(const size_t size, const T *input, const int64_t old_depth, const int64_t old_height,
const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top,
const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
int64_t nc = pos / padded_width;
const int64_t out_w = pos % padded_width;
const int64_t out_h = nc % padded_height;
nc /= padded_height;
const int64_t out_d = nc % padded_depth;
nc /= padded_depth;
int in_d = ((out_d - pad_head) % old_depth + old_depth) % old_depth;
int in_h = ((out_h - pad_top) % old_height + old_height) % old_height;
int in_w = ((out_w - pad_left) % old_width + old_width) % old_width;
if (out_d < pad_head) {
in_d = (in_d + imin(0, pad_back) + old_depth) % old_depth;
}
if (out_d >= old_depth + pad_head) {
in_d = (in_d + imax(0, -pad_head) + old_depth) % old_depth;
}
if (out_h < pad_top) {
in_h = (in_h + imin(0, pad_down) + old_height) % old_height;
}
if (out_h >= old_height + pad_top) {
in_h = (in_h + imax(0, -pad_top) + old_height) % old_height;
}
if (out_w < pad_left) {
in_w = (in_w + imin(0, pad_right) + old_width) % old_width;
}
if (out_w >= old_width + pad_left) {
in_w = (in_w + imax(0, -pad_left) + old_width) % old_width;
}
output[pos] = input[(nc * old_depth * old_height + in_d * old_height + in_h) * old_width + in_w];
}
}
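For non-negative paddings, the double modulo `((out - pad) % old + old) % old` is just a Euclidean remainder, i.e. plain wrap-around; the `imin`/`imax` corrections only take effect for negative (cropping) paddings. As a reference point, NumPy's wrap mode reproduces the positive-pad case:

import numpy as np

x = np.arange(9, dtype=np.float32).reshape(1, 3, 3)
ref = np.pad(x, ((0, 0), (0, 0), (1, 2)), mode="wrap")  # pad_left=1, pad_right=2
print(ref[0, 0])  # [2. 0. 1. 2. 0. 1.]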
template <typename T>
__global__ void CircularPadGrad3d(const size_t size, const T *input, const int64_t old_depth, const int64_t old_height,
const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top,
const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
int nc = pos / old_width;
const int out_w = pos % old_width;
const int out_h = nc % old_height;
nc /= old_height;
const int out_d = nc % old_depth;
nc /= old_depth;
int in_d = ((out_d - pad_head) % padded_depth + padded_depth) % padded_depth;
int in_h = ((out_h - pad_top) % padded_height + padded_height) % padded_height;
int in_w = ((out_w - pad_left) % padded_width + padded_width) % padded_width;
if (out_d < pad_head) {
in_d = (in_d + imin(0, pad_back) + padded_depth) % padded_depth;
}
if (out_d >= padded_depth + pad_head) {
in_d = (in_d + imax(0, -pad_head) + padded_depth) % padded_depth;
}
if (out_h < pad_top) {
in_h = (in_h + imin(0, pad_down) + padded_height) % padded_height;
}
if (out_h >= padded_height + pad_top) {
in_h = (in_h + imax(0, -pad_top) + padded_height) % padded_height;
}
if (out_w < pad_left) {
in_w = (in_w + imin(0, pad_right) + padded_width) % padded_width;
}
if (out_w >= padded_width + pad_left) {
in_w = (in_w + imax(0, -pad_left) + padded_width) % padded_width;
}
int index = (nc * padded_depth * padded_height + in_d * padded_height + in_h) * padded_width + in_w;
MsAtomicAdd(&output[index], input[pos]);
}
}
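Several padded output positions can wrap to the same source cell, so the gradient kernel accumulates with `MsAtomicAdd` instead of plain stores. `np.add.at` is the serial analogue (illustrative 1-D data: size 3, pads (2, 1)):

import numpy as np

dy = np.array([1., 2., 3., 4., 5., 6.])
src = np.array([1, 2, 0, 1, 2, 0])  # wrapped source index of each dy slot
dx = np.zeros(3)
np.add.at(dx, src, dy)  # unbuffered accumulation, like the atomic adds
print(dx)  # [9. 5. 7.]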
template <typename T>
__global__ void ReflectPad3d(const size_t size, const T *input, const int64_t num, const int64_t channels,
const int64_t old_depth, const int64_t old_height, const int64_t old_width,
@ -221,6 +302,28 @@ void CalConstantPadGrad3d(const size_t size, const T *dy, const int64_t num, con
padded_width, padded_dhw, padded_hw, pad_head, pad_top, pad_left, dx);
}
template <typename T>
void CalCircularPad3d(const size_t size, const T *input, const int64_t old_depth, const int64_t old_height,
const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const int64_t pad_back, const int64_t pad_down, const int64_t pad_right, T *output,
const uint32_t &device_id, cudaStream_t cuda_stream) {
CircularPad3d<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size, input, old_depth, old_height, old_width, padded_depth, padded_height, padded_width, pad_head, pad_top,
pad_left, pad_back, pad_down, pad_right, output);
}
template <typename T>
void CalCircularPadGrad3d(const size_t size, const T *input, const int64_t old_depth, const int64_t old_height,
const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top,
const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, T *output, const uint32_t &device_id, cudaStream_t cuda_stream) {
CircularPadGrad3d<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size, input, old_depth, old_height, old_width, padded_depth, padded_height, padded_width, pad_head, pad_top,
pad_left, pad_back, pad_down, pad_right, output);
}
template <typename T>
void CalReflectPad3d(const size_t size, const T *input, const int64_t num, const int64_t channels,
const int64_t old_depth, const int64_t old_height, const int64_t old_width,
@ -286,6 +389,13 @@ template CUDA_LIB_EXPORT void CalConstantPad3d<double>(
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const double *pad_value, double *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<double>(
const size_t size, const double *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, double *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<double>(const size_t size, const double *input, const int64_t num,
const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width,
@ -316,11 +426,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<double>(const size_t size, double
const int64_t pad_top, const int64_t pad_left, double *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<double>(
const size_t size, const double *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, double *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<float>(
const size_t size, const float *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const float *pad_value, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<float>(
const size_t size, const float *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<float>(const size_t size, const float *input, const int64_t num,
const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width,
@ -351,11 +474,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<float>(const size_t size, float *
const int64_t pad_top, const int64_t pad_left, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<float>(
const size_t size, const float *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<half>(
const size_t size, const half *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const half *pad_value, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<half>(
const size_t size, const half *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<half>(const size_t size, const half *input, const int64_t num,
const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width,
@ -386,11 +522,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<half>(const size_t size, half *in
const int64_t pad_top, const int64_t pad_left, half *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<half>(
const size_t size, const half *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<int64_t>(
const size_t size, const int64_t *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const int64_t *pad_value, int64_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<int64_t>(
const size_t size, const int64_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, int64_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<int64_t>(const size_t size, const int64_t *input, const int64_t num,
const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width,
@ -419,11 +568,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<int64_t>(const size_t size, int64
const int64_t pad_top, const int64_t pad_left, int64_t *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<int64_t>(
const size_t size, const int64_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, int64_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<int32_t>(
const size_t size, const int32_t *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const int32_t *pad_value, int32_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<int32_t>(
const size_t size, const int32_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, int32_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<int32_t>(const size_t size, const int32_t *input, const int64_t num,
const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width,
@ -452,11 +614,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<int32_t>(const size_t size, int32
const int64_t pad_top, const int64_t pad_left, int32_t *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<int32_t>(
const size_t size, const int32_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, int32_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<int16_t>(
const size_t size, const int16_t *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const int16_t *pad_value, int16_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<int16_t>(
const size_t size, const int16_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, int16_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<int16_t>(const size_t size, const int16_t *input, const int64_t num,
const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width,
@ -485,11 +660,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<int16_t>(const size_t size, int16
const int64_t pad_top, const int64_t pad_left, int16_t *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<int16_t>(
const size_t size, const int16_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, int16_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<int8_t>(
const size_t size, const int8_t *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const int8_t *pad_value, int8_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<int8_t>(
const size_t size, const int8_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, int8_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<int8_t>(const size_t size, const int8_t *input, const int64_t num,
const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width,
@ -520,11 +708,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<int8_t>(const size_t size, int8_t
const int64_t pad_top, const int64_t pad_left, int8_t *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<int8_t>(
const size_t size, const int8_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, int8_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<uint64_t>(
const size_t size, const uint64_t *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const uint64_t *pad_value, uint64_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<uint64_t>(
const size_t size, const uint64_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, uint64_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<uint64_t>(const size_t size, const uint64_t *input, const int64_t num,
const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width,
@ -551,11 +752,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<uint64_t>(
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left, uint64_t *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<uint64_t>(
const size_t size, const uint64_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, uint64_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<uint32_t>(
const size_t size, const uint32_t *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const uint32_t *pad_value, uint32_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<uint32_t>(
const size_t size, const uint32_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, uint32_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<uint32_t>(const size_t size, const uint32_t *input, const int64_t num,
const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width,
@ -582,11 +796,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<uint32_t>(
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left, uint32_t *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<uint32_t>(
const size_t size, const uint32_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, uint32_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<uint16_t>(
const size_t size, const uint16_t *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const uint16_t *pad_value, uint16_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<uint16_t>(
const size_t size, const uint16_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, uint16_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<uint16_t>(const size_t size, const uint16_t *input, const int64_t num,
const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width,
@ -613,11 +840,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<uint16_t>(
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left, uint16_t *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<uint16_t>(
const size_t size, const uint16_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, uint16_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<uint8_t>(
const size_t size, const uint8_t *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const uint8_t *pad_value, uint8_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<uint8_t>(
const size_t size, const uint8_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, uint8_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<uint8_t>(const size_t size, const uint8_t *input, const int64_t num,
const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width,
@ -646,11 +886,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<uint8_t>(const size_t size, uint8
const int64_t pad_top, const int64_t pad_left, uint8_t *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<uint8_t>(
const size_t size, const uint8_t *input, const int64_t old_depth, const int64_t old_height, const int64_t old_width,
const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, uint8_t *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<Complex<float>>(
const size_t size, const Complex<float> *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const Complex<float> *pad_value, Complex<float> *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<Complex<float>>(
const size_t size, const Complex<float> *input, const int64_t old_depth, const int64_t old_height,
const int64_t old_width, const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width,
const int64_t pad_head, const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, Complex<float> *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<Complex<float>>(
const size_t size, const Complex<float> *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
@ -673,11 +926,24 @@ template CUDA_LIB_EXPORT void CalEdgePadGrad3d<Complex<float>>(
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
Complex<float> *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPadGrad3d<Complex<float>>(
const size_t size, const Complex<float> *input, const int64_t old_depth, const int64_t old_height,
const int64_t old_width, const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width,
const int64_t pad_head, const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, Complex<float> *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalConstantPad3d<Complex<double>>(
const size_t size, const Complex<double> *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,
const int64_t padded_width, const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const Complex<double> *pad_value, Complex<double> *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalCircularPad3d<Complex<double>>(
const size_t size, const Complex<double> *input, const int64_t old_depth, const int64_t old_height,
const int64_t old_width, const int64_t padded_depth, const int64_t padded_height, const int64_t padded_width,
const int64_t pad_head, const int64_t pad_top, const int64_t pad_left, const int64_t pad_back, const int64_t pad_down,
const int64_t pad_right, Complex<double> *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalReflectPad3d<Complex<double>>(
const size_t size, const Complex<double> *input, const int64_t num, const int64_t channels, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth, const int64_t padded_height,

View File

@ -19,6 +19,14 @@
#include <cuda_runtime.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T>
CUDA_LIB_EXPORT void CalCircularPad3d(const size_t size, const T *input, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth,
const int64_t padded_height, const int64_t padded_width, const int64_t pad_head,
const int64_t pad_top, const int64_t pad_left, const int64_t pad_back,
const int64_t pad_down, const int64_t pad_right, T *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void CalConstantPad3d(const size_t size, const T *input, const int64_t num, const int64_t channels,
const int64_t old_depth, const int64_t old_height, const int64_t old_width,
@ -66,4 +74,12 @@ CUDA_LIB_EXPORT void CalEdgePadGrad3d(const size_t size, T *input, const int64_t
const int64_t pad_left, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void CalCircularPadGrad3d(const size_t size, const T *input, const int64_t old_depth,
const int64_t old_height, const int64_t old_width, const int64_t padded_depth,
const int64_t padded_height, const int64_t padded_width,
const int64_t pad_head, const int64_t pad_top, const int64_t pad_left,
const int64_t pad_back, const int64_t pad_down, const int64_t pad_right,
T *output, const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_PAD_V3_IMPL_CUH_

View File

@ -18,7 +18,7 @@
namespace mindspore {
namespace kernel {
namespace {
const std::vector<std::string> mode_list = {ops::kReflect, ops::kEdge};
const std::vector<std::string> mode_list = {ops::kReflect, ops::kEdge, ops::kCircular};
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreatePadV3GradKernelPtr(const std::string &kernel_name,
const uint32_t &device_id) {
@ -52,6 +52,30 @@ const std::vector<std::pair<KernelAttr, PadV3GradPtrCreatorFunc>> kernel_attr =
CreatePadV3GradKernelPtr<uint8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
CreatePadV3GradKernelPtr<Complex<float>, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
CreatePadV3GradKernelPtr<double, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
CreatePadV3GradKernelPtr<float, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
CreatePadV3GradKernelPtr<half, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
CreatePadV3GradKernelPtr<int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
CreatePadV3GradKernelPtr<int32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
CreatePadV3GradKernelPtr<int16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
CreatePadV3GradKernelPtr<int8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
CreatePadV3GradKernelPtr<uint64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
CreatePadV3GradKernelPtr<uint32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
CreatePadV3GradKernelPtr<uint16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
CreatePadV3GradKernelPtr<uint8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
CreatePadV3GradKernelPtr<Complex<float>, int64_t>},
};
} // namespace

View File

@ -340,6 +340,7 @@ constexpr auto kHalfPixelCenters = "half_pixel_centers";
constexpr auto kConstant = "constant";
constexpr auto kReflect = "reflect";
constexpr auto kEdge = "edge";
constexpr auto kCircular = "circular";
constexpr auto kLr = "lr";
constexpr auto kL1 = "l1";
constexpr auto kL2 = "l2";

View File

@ -139,14 +139,15 @@ abstract::ShapePtr PadV3InferShape(const PrimitivePtr &primitive, const std::vec
std::vector<int64_t> paddings_val;
auto mode = GetValue<std::string>(primitive->GetAttr(kAttrMode));
if (mode != kConstant) {
(void)CheckAndConvertUtils::CheckInteger("input dims for edge or reflect mode", size, kGreaterEqual, kOtherMinDims,
prim_name);
}
if (mode == kReflect) {
ReflectModeCheck(prim_name, paddings_size, x_shape, paddings_arg, size);
} else if (mode == kEdge) {
(void)CheckAndConvertUtils::CheckInteger("input dims for edge mode", size, kLessEqual, kEdgeMaxDims, prim_name);
(void)CheckAndConvertUtils::CheckInteger("input dims for edge, reflect or circular mode", size, kGreaterEqual,
kOtherMinDims, prim_name);
if (mode == kReflect) {
ReflectModeCheck(prim_name, paddings_size, x_shape, paddings_arg, size);
} else {
(void)CheckAndConvertUtils::CheckInteger("input dims for edge mode", size, kLessEqual, kEdgeMaxDims, prim_name);
}
}
PaddingsSizeCheck(primitive, paddings_size, size);
for (int64_t i = 0; i < paddings_size; ++i) {
paddings_val.push_back(int64_t(paddings_arg[LongToSize(i)]));

View File

@ -2211,7 +2211,7 @@ class PadV3Grad(Primitive):
"""Initialize Padv3Grad"""
self.add_prim_attr("cust_aicpu", self.name)
self.init_prim_io_names(inputs=['x', 'paddings'], outputs=['y'])
validator.check_string(mode, ['reflect', 'edge'], 'mode', self.name)
validator.check_string(mode, ['reflect', 'edge', 'circular'], 'mode', self.name)
validator.check_bool(paddings_contiguous, "paddings_contiguous", self.name)
self.set_const_input_indexes([1])
self.mode = mode

View File

@ -4275,7 +4275,7 @@ class PadV3(Primitive):
Args:
mode (str, optional): An optional string indicating the padding mode,
supports "constant", "reflect", "edge". Default: "constant".
supports "constant", "reflect", "edge", "circular". Default: "constant".
paddings_contiguous (bool, optional): An optional bool value indicating whether the paddings are contiguous.
If true, paddings is arranged as [begin0, end0, begin1, end1, ...]
If false, paddings is arranged as [begin0, begin1, ..., end0, end1, ...]
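A small sketch of the reordering implied by paddings_contiguous=False (helper name hypothetical):

def to_contiguous(p):
    # [begin0, begin1, ..., end0, end1, ...] -> [begin0, end0, begin1, end1, ...]
    half = len(p) // 2
    out = []
    for b, e in zip(p[:half], p[half:]):
        out += [b, e]
    return out

print(to_contiguous([1, 1, 2, 2]))  # [1, 2, 1, 2]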
@ -4296,13 +4296,14 @@ class PadV3(Primitive):
ValueError: If `mode` is not a str or is not in the supported modes.
ValueError: If `mode` is "constant", the number of elements of `paddings` is not even.
ValueError: If `mode` is "constant", the number of elements of `paddings` is larger than the input dim * 2.
ValueError: If `mode` is "edge" or "reflect", the number of elements of `paddings` is not 2, 4 or 6.
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 3,
ValueError: If `mode` is "edge", "reflect" or "circular", the number of elements of `paddings` is not 2, 4 or 6.
ValueError: If `mode` is "edge", "reflect" or "circular", `x` dims equals 3,
the number of elements of `paddings` is not 2.
ValueError: If `mode` is "edge" or "reflect", `x` dims equals 4,
ValueError: If `mode` is "edge", "reflect" or "circular", `x` dims equals 4,
the number of elements of `paddings` is not 4.
ValueError: If `mode` is "edge" or "reflect", `x` dims is smaller than 3.
ValueError: If `mode` is "edge", `x` dims is bigger than 5.
ValueError: If `mode` is "circular", `x` dims equals 5, and the number of elements of `paddings` is not 6.
ValueError: If `mode` is "edge", "reflect" or "circular", `x` dims is smaller than 3.
ValueError: If `mode` is "edge" or "circular", `x` dims is bigger than 5.
ValueError: If `mode` is "reflect", `x` dims is bigger than 4.
ValueError: If `mode` is "reflect", the padding size is bigger than the corresponding `x` dimension.
ValueError: If any dimension of the output shape is not greater than 0 after padding.
@ -4346,7 +4347,7 @@ class PadV3(Primitive):
def __init__(self, mode='constant', paddings_contiguous=True):
"""Initialize PadV3"""
self.init_prim_io_names(inputs=['x', 'paddings', 'constant_value'], outputs=['y'])
validator.check_string(mode, ['constant', 'reflect', 'edge'], 'mode', self.name)
validator.check_string(mode, ['constant', 'reflect', 'edge', 'circular'], 'mode', self.name)
validator.check_bool(paddings_contiguous, "paddings_contiguous", self.name)
self.mode = mode
self.paddings_contiguous = paddings_contiguous
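One behavior worth spelling out (inferred from the kernels' min(0, pad) handling and the 4-D dynamic-shape test in this PR): a negative padding in circular mode crops that side first, and the remaining data is then wrapped. In NumPy terms:

import numpy as np

x = np.arange(9.).reshape(3, 3)
w = np.pad(x[:, :2], ((0, 0), (1, 0)), mode="wrap")  # pad_l=1, pad_r=-1: crop right, then wrap left
out = np.pad(w, ((1, 2), (0, 0)), mode="wrap")       # pad_t=1, pad_d=2
print(out.shape, out[0])  # (6, 3) [7. 6. 7.]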

View File

@ -1,4 +1,4 @@
# Copyright 2022 Huawei Technologies Co., Ltd
# Copyright 2022-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -38,6 +38,83 @@ class NetDynamic(nn.Cell):
return out
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_padv3_circular_dynamic_shape_3d():
"""
Feature: test padv3 x and padding dynamic shape
Description: test padv3 dynamic shape
Expectation: Success
"""
context.set_context(device_target="CPU", save_graphs=False)
x = Tensor(np.arange(9).reshape(1, 3, 3).astype(np.float32))
padding = Tensor((1, 2), dtype=mindspore.int64)
net = NetDynamic('circular')
x_dyn = Tensor(shape=(1, 3, None), dtype=x.dtype)
padding_dyn = Tensor(shape=(None,), dtype=padding.dtype)
net.set_inputs(x_dyn, padding_dyn)
out = net(x, padding)
expect = np.array([[[2, 0, 1, 2, 0, 1],
[5, 3, 4, 5, 3, 4],
[8, 6, 7, 8, 6, 7]]]).astype(np.float32)
np.testing.assert_almost_equal(expect, out.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_padv3_circular_dynamic_shape_4d():
"""
Feature: test padv3 x and padding dynamic shape
Description: test padv3 dynamic shape
Expectation: Success
"""
context.set_context(device_target="CPU", save_graphs=False)
x = Tensor(np.arange(9).reshape(1, 1, 3, 3).astype(np.float64))
padding = Tensor((1, -1, 1, 2), dtype=mindspore.int32)
net = NetDynamic('circular')
x_dyn = Tensor(shape=(1, 1, 3, None), dtype=x.dtype)
padding_dyn = Tensor(shape=(None,), dtype=padding.dtype)
net.set_inputs(x_dyn, padding_dyn)
out = net(x, padding)
expect = np.array([[[[7, 6, 7], [1, 0, 1], [4, 3, 4],
[7, 6, 7], [1, 0, 1], [4, 3, 4]]]]).astype(np.float64)
np.testing.assert_almost_equal(expect, out.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_padv3_circular_dynamic_shape_5d():
    """
    Feature: test padv3 x and padding dynamic shape
    Description: test padv3 dynamic shape
    Expectation: Success
    """
    context.set_context(device_target="CPU", save_graphs=False)
    x = Tensor(np.arange(18).reshape(1, 1, 2, 3, 3).astype(np.float64))
    padding = Tensor((0, 1, 1, -1, 0, -1), dtype=mindspore.int32)
    net = NetDynamic('circular')
    x_dyn = Tensor(shape=(1, 1, None, 3, None), dtype=x.dtype)
    padding_dyn = Tensor(shape=(None,), dtype=padding.dtype)
    net.set_inputs(x_dyn, padding_dyn)
    out = net(x, padding)
    expect = np.array([[[[[3, 4, 5, 3],
                          [0, 1, 2, 0],
                          [3, 4, 5, 3]]]]]).astype(np.float64)
    np.testing.assert_almost_equal(expect, out.asnumpy())
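The same hypothetical helper reproduces the 5-D expectation: circular_pad_ref(np.arange(18).reshape(1, 1, 2, 3, 3).astype(np.float64), [0, 1, 1, -1, 0, -1]) returns an array of shape (1, 1, 1, 3, 4) equal to `expect` above; the (0, -1) pair on the depth axis drops the second depth slice entirely.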


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard

View File

@ -0,0 +1,92 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations._grad_ops as Grad


class PadV3GradNet(nn.Cell):
    def __init__(self, mode):
        super(PadV3GradNet, self).__init__()
        self.op = Grad.PadV3Grad(mode)

    def construct(self, x, paddings):
        return self.op(x, paddings)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_padv3grad_circular_3d():
    """
    Feature: test PadV3Grad
    Description: test PadV3Grad circular mode.
    Expectation: Success
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
    net = PadV3GradNet('circular')
    x = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mindspore.float32).reshape((1, 2, 4))
    paddings = Tensor(np.array([1, 1], dtype=np.int32))
    output = net(x, paddings)
    expect = np.array([[[6, 4], [14, 12]]]).astype(np.float32)
    np.testing.assert_almost_equal(expect, output.asnumpy())
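To make the expected numbers less magical: in circular mode, padded output position j along an axis reads input position (j - pad_l) mod w, so the backward pass scatter-adds the incoming gradient along that same mapping, one axis at a time. Below is a NumPy sketch under that reading; circular_pad_grad_ref is an illustrative helper (non-negative paddings, as in these tests), not the kernel's API.

import numpy as np

def circular_pad_grad_ref(dy, paddings):
    """Reference gradient of circular padding. `paddings` is a flat
    [begin, end, ...] list, last dimension first, non-negative."""
    dx = dy
    ndim = dy.ndim
    for k in range(len(paddings) // 2):
        axis = ndim - 1 - k
        pad_l, pad_r = paddings[2 * k], paddings[2 * k + 1]
        w = dx.shape[axis] - pad_l - pad_r  # unpadded extent of this axis
        acc = np.zeros(dx.shape[:axis] + (w,) + dx.shape[axis + 1:], dx.dtype)
        for j in range(dx.shape[axis]):
            src = [slice(None)] * ndim
            dst = [slice(None)] * ndim
            src[axis], dst[axis] = j, (j - pad_l) % w
            acc[tuple(dst)] += dx[tuple(src)]  # scatter-add back to the source slot
        dx = acc
    return dx

dy = np.arange(1, 9, dtype=np.float32).reshape(1, 2, 4)
print(circular_pad_grad_ref(dy, [1, 1]))  # [[[6. 4.] [14. 12.]]], matching `expect`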


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_padv3grad_circular_4d():
    """
    Feature: test PadV3Grad
    Description: test PadV3Grad circular mode.
    Expectation: Success
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
    net = PadV3GradNet('circular')
    x = Tensor(np.arange(18).reshape(1, 1, 3, 6).astype(np.float32))
    paddings = Tensor(np.array([1, 2, 1, 0], dtype=np.int64))
    output = net(x, paddings)
    expect = np.array([[[[17., 19., 15.], [34., 38., 30.]]]]).astype(np.float32)
    np.testing.assert_almost_equal(expect, output.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_padv3grad_circular_5d():
    """
    Feature: test PadV3Grad
    Description: test PadV3Grad circular mode.
    Expectation: Success
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="CPU")
    net = PadV3GradNet('circular')
    x = Tensor(np.arange(80).reshape(1, 1, 5, 4, 4).astype(np.float64))
    paddings = Tensor(np.array([1, 0, 1, 1, 2, 1], dtype=np.int64))
    output = net(x, paddings)
    expect = np.array([[[[[246., 252., 498.], [222., 228., 450.]],
                         [[164., 168., 332.], [148., 152., 300.]]]]]).astype(np.float64)
    np.testing.assert_almost_equal(expect, output.asnumpy())
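The same sketch reproduces the other two cases: circular_pad_grad_ref(np.arange(18).reshape(1, 1, 3, 6).astype(np.float32), [1, 2, 1, 0]) gives [[[[17, 19, 15], [34, 38, 30]]]], and circular_pad_grad_ref(np.arange(80).reshape(1, 1, 5, 4, 4).astype(np.float64), [1, 0, 1, 1, 2, 1]) matches the 5-D `expect` above.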

View File

@ -1,4 +1,4 @@
# Copyright 2022 Huawei Technologies Co., Ltd
# Copyright 2022-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -53,6 +53,83 @@ class NetDynamic(nn.Cell):
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_padv3_circular_dynamic_shape_3d():
    """
    Feature: test padv3 x and padding dynamic shape
    Description: test padv3 dynamic shape
    Expectation: Success
    """
    context.set_context(device_target="GPU", save_graphs=False)
    x = Tensor(np.arange(9).reshape(1, 3, 3).astype(np.float32))
    padding = Tensor((1, 2), dtype=mindspore.int64)
    net = NetDynamic('circular')
    x_dyn = Tensor(shape=(1, 3, None), dtype=x.dtype)
    padding_dyn = Tensor(shape=(None,), dtype=padding.dtype)
    net.set_inputs(x_dyn, padding_dyn)
    out = net(x, padding)
    expect = np.array([[[2, 0, 1, 2, 0, 1],
                        [5, 3, 4, 5, 3, 4],
                        [8, 6, 7, 8, 6, 7]]]).astype(np.float32)
    np.testing.assert_almost_equal(expect, out.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_padv3_circular_dynamic_shape_4d():
    """
    Feature: test padv3 x and padding dynamic shape
    Description: test padv3 dynamic shape
    Expectation: Success
    """
    context.set_context(device_target="GPU", save_graphs=False)
    x = Tensor(np.arange(9).reshape(1, 1, 3, 3).astype(np.float64))
    padding = Tensor((1, -1, 1, 2), dtype=mindspore.int32)
    net = NetDynamic('circular')
    x_dyn = Tensor(shape=(1, 1, 3, None), dtype=x.dtype)
    padding_dyn = Tensor(shape=(None,), dtype=padding.dtype)
    net.set_inputs(x_dyn, padding_dyn)
    out = net(x, padding)
    expect = np.array([[[[7, 6, 7], [1, 0, 1], [4, 3, 4],
                         [7, 6, 7], [1, 0, 1], [4, 3, 4]]]]).astype(np.float64)
    np.testing.assert_almost_equal(expect, out.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_padv3_circular_dynamic_shape_5d():
    """
    Feature: test padv3 x and padding dynamic shape
    Description: test padv3 dynamic shape
    Expectation: Success
    """
    context.set_context(device_target="GPU", save_graphs=False)
    x = Tensor(np.arange(18).reshape(1, 1, 2, 3, 3).astype(np.float64))
    padding = Tensor((0, 1, 1, -1, 0, -1), dtype=mindspore.int32)
    net = NetDynamic('circular')
    x_dyn = Tensor(shape=(1, 1, None, 3, None), dtype=x.dtype)
    padding_dyn = Tensor(shape=(None,), dtype=padding.dtype)
    net.set_inputs(x_dyn, padding_dyn)
    out = net(x, padding)
    expect = np.array([[[[[3, 4, 5, 3],
                          [0, 1, 2, 0],
                          [3, 4, 5, 3]]]]]).astype(np.float64)
    np.testing.assert_almost_equal(expect, out.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard

View File

@ -0,0 +1,92 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations._grad_ops as Grad


class PadV3GradNet(nn.Cell):
    def __init__(self, mode):
        super(PadV3GradNet, self).__init__()
        self.op = Grad.PadV3Grad(mode)

    def construct(self, x, paddings):
        return self.op(x, paddings)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_padv3grad_circular_3d():
    """
    Feature: test PadV3Grad
    Description: test PadV3Grad circular mode.
    Expectation: Success
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
    net = PadV3GradNet('circular')
    x = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mindspore.float32).reshape((1, 2, 4))
    paddings = Tensor(np.array([1, 1], dtype=np.int32))
    output = net(x, paddings)
    expect = np.array([[[6, 4], [14, 12]]]).astype(np.float32)
    np.testing.assert_almost_equal(expect, output.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_padv3grad_circular_4d():
    """
    Feature: test PadV3Grad
    Description: test PadV3Grad circular mode.
    Expectation: Success
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
    net = PadV3GradNet('circular')
    x = Tensor(np.arange(18).reshape(1, 1, 3, 6).astype(np.float32))
    paddings = Tensor(np.array([1, 2, 1, 0], dtype=np.int64))
    output = net(x, paddings)
    expect = np.array([[[[17., 19., 15.], [34., 38., 30.]]]]).astype(np.float32)
    np.testing.assert_almost_equal(expect, output.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_padv3grad_circular_5d():
    """
    Feature: test PadV3Grad
    Description: test PadV3Grad circular mode.
    Expectation: Success
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
    net = PadV3GradNet('circular')
    x = Tensor(np.arange(80).reshape(1, 1, 5, 4, 4).astype(np.float64))
    paddings = Tensor(np.array([1, 0, 1, 1, 2, 1], dtype=np.int64))
    output = net(x, paddings)
    expect = np.array([[[[[246., 252., 498.], [222., 228., 450.]],
                         [[164., 168., 332.], [148., 152., 300.]]]]]).astype(np.float64)
    np.testing.assert_almost_equal(expect, output.asnumpy())