!34307 support ResizeBilinear and ResizeBilinearGrad op

Merge pull request !34307 from hanhuifeng/resize_bilinear
This commit is contained in:
i-robot 2022-06-14 02:50:34 +00:00 committed by Gitee
commit cbd4720e14
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
29 changed files with 1138 additions and 68 deletions

View File

@ -54,6 +54,7 @@ const std::map<std::string, std::string> opTypeAdapter = {{"ReLUV2", "ReluV2"},
{"TransposeNOD", "Transpose"},
{"ParallelResizeBilinear", "SyncResizeBilinearV2"},
{"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"},
{"ResizeBilinearGrad", "ResizeBilinearV2Grad"},
{"Split", "SplitD"},
{"HSwish", "HardSwish"},
{"HSwishGrad", "HardSwishGrad"},

View File

@ -59,6 +59,7 @@ std::string OpTilingCalculateAdapter::GetRealOpType(const std::string &op_type)
{"DynamicResizeNearestNeighbor", "ResizeNearestNeighborV2"},
{"ParallelResizeBilinear", "SyncResizeBilinearV2"},
{"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"},
{"ResizeBilinearGrad", "ResizeBilinearV2Grad"},
{"HSwish", "HardSwish"},
{"HSwishGrad", "HardSwishGrad"},
{"CeLU", "CeluV2"},

View File

@ -18,8 +18,8 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
template <typename T>
__global__ void ResizeBilinear(const T *input, const int n, const int c, const int input_h, const int input_w,
const int output_h, const int output_w, const int nchw, const int chw, const int hw, const float h_scale,
const float w_scale, T *output) {
const int output_h, const int output_w, const int nchw, const int chw, const int hw,
const float h_scale, const float w_scale, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) {
const int posn = pos / chw;
const int posc = pos / hw % c;
@ -27,29 +27,59 @@ __global__ void ResizeBilinear(const T *input, const int n, const int c, const i
const int posw = pos % output_w;
const float posw_scaled = w_scale * posw;
const float posh_scaled = h_scale * posh;
const int w_low = max(static_cast<int>(floorf(posw_scaled)), 0); // NOLINT
const int w_high = min(static_cast<int>(ceilf(posw_scaled)), input_w - 1); // NOLINT
const int h_low = max(static_cast<int>(floorf(posh_scaled)), 0); // NOLINT
const int h_high = min(static_cast<int>(ceilf(posh_scaled)), input_h - 1); // NOLINT
const int w_low = max(static_cast<int>(floorf(posw_scaled)), 0); // NOLINT
const int w_high = min(static_cast<int>(ceilf(posw_scaled)), input_w - 1); // NOLINT
const int h_low = max(static_cast<int>(floorf(posh_scaled)), 0); // NOLINT
const int h_high = min(static_cast<int>(ceilf(posh_scaled)), input_h - 1); // NOLINT
const float w_alpha = posw_scaled - w_low;
const float w_beta = 1.0f - w_alpha;
const float h_alpha = posh_scaled - h_low;
const float h_beta = 1.0f - h_alpha;
const int input_start = input_h * input_w * (posn * c + posc);
const int input_start = input_h * input_w * (posn * c + posc);
const T p1 = input[input_start + (h_low * input_w) + w_low];
const T p2 = input[input_start + (h_low * input_w) + w_high];
const T p3 = input[input_start + (h_high * input_w) + w_low];
const T p4 = input[input_start + (h_high * input_w) + w_high];
output[pos] = (p1 * static_cast<T>(h_beta * w_beta)) + (p2 * static_cast<T>(h_beta * w_alpha))
+ (p3 * static_cast<T>(h_alpha * w_beta)) + (p4 * static_cast<T>(h_alpha * w_alpha));
output[pos] = (p1 * static_cast<T>(h_beta * w_beta)) + (p2 * static_cast<T>(h_beta * w_alpha)) +
(p3 * static_cast<T>(h_alpha * w_beta)) + (p4 * static_cast<T>(h_alpha * w_alpha));
}
return;
}
template <typename T>
__global__ void ResizeBilinear_HPC(const T *input, const int n, const int c, const int input_h, const int input_w,
const int output_h, const int output_w, const int nchw, const int chw, const int hw,
const float h_scale, const float w_scale, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) {
const int posn = pos / chw;
const int posc = pos / hw % c;
const int posh = pos / output_w % output_h;
const int posw = pos % output_w;
const float posw_scaled = (static_cast<float>(posw) + 0.5f) * w_scale - 0.5f;
const float posh_scaled = (static_cast<float>(posh) + 0.5f) * h_scale - 0.5f;
const int w_low = max(static_cast<int>(floorf(posw_scaled)), 0); // NOLINT
const int w_high = min(static_cast<int>(ceilf(posw_scaled)), input_w - 1); // NOLINT
const int h_low = max(static_cast<int>(floorf(posh_scaled)), 0); // NOLINT
const int h_high = min(static_cast<int>(ceilf(posh_scaled)), input_h - 1); // NOLINT
const float w_alpha = posw_scaled - floorf(posw_scaled);
const float w_beta = 1.0f - w_alpha;
const float h_alpha = posh_scaled - floorf(posh_scaled);
const float h_beta = 1.0f - h_alpha;
const int input_start = input_h * input_w * (posn * c + posc);
const T p1 = input[input_start + (h_low * input_w) + w_low];
const T p2 = input[input_start + (h_low * input_w) + w_high];
const T p3 = input[input_start + (h_high * input_w) + w_low];
const T p4 = input[input_start + (h_high * input_w) + w_high];
output[pos] = (p1 * static_cast<T>(h_beta * w_beta)) + (p2 * static_cast<T>(h_beta * w_alpha)) +
(p3 * static_cast<T>(h_alpha * w_beta)) + (p4 * static_cast<T>(h_alpha * w_alpha));
}
return;
}
// fp16 path
__global__ void ResizeBilinearGrad(const half *input, const int n, const int c, const int input_h, const int input_w,
const int output_h, const int output_w, const int nchw, const int chw, const int hw, const float h_scale,
const float w_scale, half *output, float *interim) {
const int output_h, const int output_w, const int nchw, const int chw, const int hw,
const float h_scale, const float w_scale, half *output, float *interim) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) {
const int posn = pos / chw;
const int posc = pos / hw % c;
@ -57,10 +87,10 @@ __global__ void ResizeBilinearGrad(const half *input, const int n, const int c,
const int posw = pos % input_w;
const float posw_scaled = w_scale * posw;
const float posh_scaled = h_scale * posh;
const int w_low = max(static_cast<int>(floorf(posw_scaled)), 0); // NOLINT
const int w_high = min(static_cast<int>(ceilf(posw_scaled)), output_w - 1); // NOLINT
const int h_low = max(static_cast<int>(floorf(posh_scaled)), 0); // NOLINT
const int h_high = min(static_cast<int>(ceilf(posh_scaled)), output_h - 1); // NOLINT
const int w_low = max(static_cast<int>(floorf(posw_scaled)), 0); // NOLINT
const int w_high = min(static_cast<int>(ceilf(posw_scaled)), output_w - 1); // NOLINT
const int h_low = max(static_cast<int>(floorf(posh_scaled)), 0); // NOLINT
const int h_high = min(static_cast<int>(ceilf(posh_scaled)), output_h - 1); // NOLINT
const float w_alpha = posw_scaled - w_low;
const float w_beta = 1.0f - w_alpha;
const float h_alpha = posh_scaled - h_low;
@ -70,7 +100,7 @@ __global__ void ResizeBilinearGrad(const half *input, const int n, const int c,
const float dp2 = h_beta * w_alpha * grad;
const float dp3 = h_alpha * w_beta * grad;
const float dp4 = h_alpha * w_alpha * grad;
const int output_start = output_h * output_w * (posn * c + posc);
const int output_start = output_h * output_w * (posn * c + posc);
atomicAdd(&interim[output_start + (h_low * output_w) + w_low], dp1);
atomicAdd(&interim[output_start + (h_low * output_w) + w_high], dp2);
atomicAdd(&interim[output_start + (h_high * output_w) + w_low], dp3);
@ -81,8 +111,8 @@ __global__ void ResizeBilinearGrad(const half *input, const int n, const int c,
// fp32 path
__global__ void ResizeBilinearGrad(const float *input, const int n, const int c, const int input_h, const int input_w,
const int output_h, const int output_w, const int nchw, const int chw, const int hw, const float h_scale,
const float w_scale, float *output, float *interim) {
const int output_h, const int output_w, const int nchw, const int chw, const int hw,
const float h_scale, const float w_scale, float *output, float *interim) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) {
const int posn = pos / chw;
const int posc = pos / hw % c;
@ -90,10 +120,10 @@ __global__ void ResizeBilinearGrad(const float *input, const int n, const int c,
const int posw = pos % input_w;
const float posw_scaled = w_scale * posw;
const float posh_scaled = h_scale * posh;
const int w_low = max(static_cast<int>(floorf(posw_scaled)), 0); // NOLINT
const int w_high = min(static_cast<int>(ceilf(posw_scaled)), output_w - 1); // NOLINT
const int h_low = max(static_cast<int>(floorf(posh_scaled)), 0); // NOLINT
const int h_high = min(static_cast<int>(ceilf(posh_scaled)), output_h - 1); // NOLINT
const int w_low = max(static_cast<int>(floorf(posw_scaled)), 0); // NOLINT
const int w_high = min(static_cast<int>(ceilf(posw_scaled)), output_w - 1); // NOLINT
const int h_low = max(static_cast<int>(floorf(posh_scaled)), 0); // NOLINT
const int h_high = min(static_cast<int>(ceilf(posh_scaled)), output_h - 1); // NOLINT
const float w_alpha = posw_scaled - w_low;
const float w_beta = 1.0f - w_alpha;
const float h_alpha = posh_scaled - h_low;
@ -103,7 +133,75 @@ __global__ void ResizeBilinearGrad(const float *input, const int n, const int c,
const float dp2 = h_beta * w_alpha * grad;
const float dp3 = h_alpha * w_beta * grad;
const float dp4 = h_alpha * w_alpha * grad;
const int output_start = output_h * output_w * (posn * c + posc);
const int output_start = output_h * output_w * (posn * c + posc);
atomicAdd(&output[output_start + (h_low * output_w) + w_low], dp1);
atomicAdd(&output[output_start + (h_low * output_w) + w_high], dp2);
atomicAdd(&output[output_start + (h_high * output_w) + w_low], dp3);
atomicAdd(&output[output_start + (h_high * output_w) + w_high], dp4);
}
return;
}
// fp16 path
__global__ void ResizeBilinearGrad_HPC(const half *input, const int n, const int c, const int input_h,
const int input_w, const int output_h, const int output_w, const int nchw,
const int chw, const int hw, const float h_scale, const float w_scale,
half *output, float *interim) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) {
const int posn = pos / chw;
const int posc = pos / hw % c;
const int posh = pos / input_w % input_h;
const int posw = pos % input_w;
const float posw_scaled = (static_cast<float>(posw) + 0.5f) * w_scale - 0.5f;
const float posh_scaled = (static_cast<float>(posh) + 0.5f) * h_scale - 0.5f;
const int w_low = max(static_cast<int>(floorf(posw_scaled)), 0); // NOLINT
const int w_high = min(static_cast<int>(ceilf(posw_scaled)), output_w - 1); // NOLINT
const int h_low = max(static_cast<int>(floorf(posh_scaled)), 0); // NOLINT
const int h_high = min(static_cast<int>(ceilf(posh_scaled)), output_h - 1); // NOLINT
const float w_alpha = posw_scaled - floorf(posw_scaled);
const float w_beta = 1.0f - w_alpha;
const float h_alpha = posh_scaled - floorf(posh_scaled);
const float h_beta = 1.0f - h_alpha;
const float grad = static_cast<float>(input[pos]);
const float dp1 = h_beta * w_beta * grad;
const float dp2 = h_beta * w_alpha * grad;
const float dp3 = h_alpha * w_beta * grad;
const float dp4 = h_alpha * w_alpha * grad;
const int output_start = output_h * output_w * (posn * c + posc);
atomicAdd(&interim[output_start + (h_low * output_w) + w_low], dp1);
atomicAdd(&interim[output_start + (h_low * output_w) + w_high], dp2);
atomicAdd(&interim[output_start + (h_high * output_w) + w_low], dp3);
atomicAdd(&interim[output_start + (h_high * output_w) + w_high], dp4);
}
return;
}
// fp32 path
__global__ void ResizeBilinearGrad_HPC(const float *input, const int n, const int c, const int input_h,
const int input_w, const int output_h, const int output_w, const int nchw,
const int chw, const int hw, const float h_scale, const float w_scale,
float *output, float *interim) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) {
const int posn = pos / chw;
const int posc = pos / hw % c;
const int posh = pos / input_w % input_h;
const int posw = pos % input_w;
const float posw_scaled = (static_cast<float>(posw) + 0.5f) * w_scale - 0.5f;
const float posh_scaled = (static_cast<float>(posh) + 0.5f) * h_scale - 0.5f;
const int w_low = max(static_cast<int>(floorf(posw_scaled)), 0); // NOLINT
const int w_high = min(static_cast<int>(ceilf(posw_scaled)), output_w - 1); // NOLINT
const int h_low = max(static_cast<int>(floorf(posh_scaled)), 0); // NOLINT
const int h_high = min(static_cast<int>(ceilf(posh_scaled)), output_h - 1); // NOLINT
const float w_alpha = posw_scaled - floorf(posw_scaled);
const float w_beta = 1.0f - w_alpha;
const float h_alpha = posh_scaled - floorf(posh_scaled);
const float h_beta = 1.0f - h_alpha;
const float grad = input[pos];
const float dp1 = h_beta * w_beta * grad;
const float dp2 = h_beta * w_alpha * grad;
const float dp3 = h_alpha * w_beta * grad;
const float dp4 = h_alpha * w_alpha * grad;
const int output_start = output_h * output_w * (posn * c + posc);
atomicAdd(&output[output_start + (h_low * output_w) + w_low], dp1);
atomicAdd(&output[output_start + (h_low * output_w) + w_high], dp2);
atomicAdd(&output[output_start + (h_high * output_w) + w_low], dp3);
@ -121,45 +219,62 @@ __global__ void ResizeBilinearGradPost(const int nchw, half *output, float *inte
template <typename T>
void CalResizeBilinear(const T *input, const int n, const int c, const int input_h, const int input_w,
const int output_h, const int output_w, const float h_scale, const float w_scale, T *output,
cudaStream_t cuda_stream) {
const int output_h, const int output_w, const float h_scale, const float w_scale,
const bool half_pixel_centers, T *output, cudaStream_t cuda_stream) {
const int nchw = n * c * output_h * output_w;
const int chw = c * output_h * output_w;
const int hw = output_h * output_w;
ResizeBilinear<<<GET_BLOCKS(nchw), GET_THREADS, 0, cuda_stream>>>(input, n, c, input_h, input_w, output_h,
output_w, nchw, chw, hw, h_scale, w_scale, output);
if (half_pixel_centers) {
ResizeBilinear_HPC<<<GET_BLOCKS(nchw), GET_THREADS, 0, cuda_stream>>>(
input, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output);
} else {
ResizeBilinear<<<GET_BLOCKS(nchw), GET_THREADS, 0, cuda_stream>>>(input, n, c, input_h, input_w, output_h, output_w,
nchw, chw, hw, h_scale, w_scale, output);
}
return;
}
void CalResizeBilinearGrad(const half *input, const int n, const int c, const int input_h, const int input_w,
const int output_h, const int output_w, const float h_scale, const float w_scale, half *output, float *interim,
cudaStream_t cuda_stream) {
const int output_h, const int output_w, const float h_scale, const float w_scale,
const bool half_pixel_centers, half *output, float *interim, cudaStream_t cuda_stream) {
const int hw = input_h * input_w;
const int chw = c * hw;
const int nchw = n * chw;
const int output_num = n * c * output_h * output_w;
ResizeBilinearGrad<<<GET_BLOCKS(nchw), GET_THREADS, 0, cuda_stream>>>(input, n, c, input_h, input_w, output_h,
output_w, nchw, chw, hw, h_scale, w_scale, output, interim);
if (half_pixel_centers) {
ResizeBilinearGrad_HPC<<<GET_BLOCKS(nchw), GET_THREADS, 0, cuda_stream>>>(
input, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output, interim);
} else {
ResizeBilinearGrad<<<GET_BLOCKS(nchw), GET_THREADS, 0, cuda_stream>>>(
input, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output, interim);
}
ResizeBilinearGradPost<<<GET_BLOCKS(output_num), GET_THREADS, 0, cuda_stream>>>(output_num, output, interim);
return;
}
void CalResizeBilinearGrad(const float *input, const int n, const int c, const int input_h, const int input_w,
const int output_h, const int output_w, const float h_scale, const float w_scale, float *output, float *interim,
cudaStream_t cuda_stream) {
const int output_h, const int output_w, const float h_scale, const float w_scale,
const bool half_pixel_centers, float *output, float *interim, cudaStream_t cuda_stream) {
const int hw = input_h * input_w;
const int chw = c * hw;
const int nchw = n * chw;
ResizeBilinearGrad<<<GET_BLOCKS(nchw), GET_THREADS, 0, cuda_stream>>>(input, n, c, input_h, input_w, output_h,
output_w, nchw, chw, hw, h_scale, w_scale, output, interim);
if (half_pixel_centers) {
ResizeBilinearGrad_HPC<<<GET_BLOCKS(nchw), GET_THREADS, 0, cuda_stream>>>(
input, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output, interim);
} else {
ResizeBilinearGrad<<<GET_BLOCKS(nchw), GET_THREADS, 0, cuda_stream>>>(
input, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output, interim);
}
return;
}
template CUDA_LIB_EXPORT void CalResizeBilinear<float>(const float *input, const int n, const int c, const int input_h,
const int input_w, const int output_h, const int output_w,
const float h_scale, const float w_scale, float *output,
const float h_scale, const float w_scale,
const bool half_pixel_centers, float *output,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalResizeBilinear<half>(const half *input, const int n, const int c, const int input_h,
const int input_w, const int output_h, const int output_w,
const float h_scale, const float w_scale, half *output,
const float h_scale, const float w_scale,
const bool half_pixel_centers, half *output,
cudaStream_t cuda_stream);

View File

@ -21,13 +21,14 @@
template <typename T>
CUDA_LIB_EXPORT void CalResizeBilinear(const T *input, const int n_, const int c_, const int input_h_,
const int input_w_, const int output_h_, const int output_w_,
const float h_scale, const float w_scale, T *output, cudaStream_t cuda_stream);
const float h_scale, const float w_scale, const bool half_pixel_centers,
T *output, cudaStream_t cuda_stream);
CUDA_LIB_EXPORT void CalResizeBilinearGrad(const half *input, const int n_, const int c_, const int input_h_,
const int input_w_, const int output_h_, const int output_w_,
const float h_scale, const float w_scale, half *output, float *interim,
cudaStream_t cuda_stream);
const float h_scale, const float w_scale, const bool half_pixel_centers,
half *output, float *interim, cudaStream_t cuda_stream);
CUDA_LIB_EXPORT void CalResizeBilinearGrad(const float *input, const int n_, const int c_, const int input_h_,
const int input_w_, const int output_h_, const int output_w_,
const float h_scale, const float w_scale, float *output, float *interim,
cudaStream_t cuda_stream);
const float h_scale, const float w_scale, const bool half_pixel_centers,
float *output, float *interim, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RESIZE_BILINEAR_IMPL_CUH_

View File

@ -22,5 +22,21 @@ MS_REG_GPU_KERNEL_ONE(ResizeBilinear, KernelAttr().AddInputAttr(kNumberTypeFloat
ResizeBilinearGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(ResizeBilinear, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ResizeBilinearGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
ResizeBilinearV2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
ResizeBilinearGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
ResizeBilinearV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
ResizeBilinearGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
ResizeBilinearV2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
ResizeBilinearGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
ResizeBilinearV2,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
ResizeBilinearGpuKernelMod, half)
} // namespace kernel
} // namespace mindspore

View File

@ -39,8 +39,8 @@ class ResizeBilinearGpuKernelMod : public DeprecatedNativeGpuKernelMod {
T *output = GetDeviceAddress<T>(outputs, 0);
float h_scale = Scaling(input_h_, output_h_, align_corners_);
float w_scale = Scaling(input_w_, output_w_, align_corners_);
CalResizeBilinear(input, n_, c_, input_h_, input_w_, output_h_, output_w_, h_scale, w_scale, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalResizeBilinear(input, n_, c_, input_h_, input_w_, output_h_, output_w_, h_scale, w_scale, half_pixel_centers_,
output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -48,15 +48,17 @@ class ResizeBilinearGpuKernelMod : public DeprecatedNativeGpuKernelMod {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
kernel_node_ = kernel_node;
if (input_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 1, but got " << input_num;
constexpr size_t kDynamicSizeInputNum = 2;
if (input_num != 1 && input_num != kDynamicSizeInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 1 or " << kDynamicSizeInputNum
<< ", but got " << input_num;
}
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num;
}
std::vector<size_t> input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
std::vector<size_t> output_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
std::vector<size_t> input_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
std::vector<size_t> output_shape = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0);
is_null_input_ =
CHECK_SHAPE_NULL(input_shape, kernel_name, "input") || CHECK_SHAPE_NULL(output_shape, kernel_name, "output");
if (is_null_input_) {
@ -83,12 +85,14 @@ class ResizeBilinearGpuKernelMod : public DeprecatedNativeGpuKernelMod {
output_size_ *= x;
}
align_corners_ = GetAttr<bool>(kernel_node, "align_corners");
half_pixel_centers_ = GetAttr<bool>(kernel_node, "half_pixel_centers");
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
align_corners_ = false;
half_pixel_centers_ = false;
is_null_input_ = false;
n_ = 0;
c_ = 0;
@ -117,6 +121,7 @@ class ResizeBilinearGpuKernelMod : public DeprecatedNativeGpuKernelMod {
}
bool align_corners_;
bool half_pixel_centers_;
bool is_null_input_;
int n_;
int c_;

View File

@ -57,7 +57,7 @@ class ResizeBilinearGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemsetAsync(interim, 0, workspace_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemsetAsync dx_interim failed");
CalResizeBilinearGrad(dy, n_, c_, dy_h_, dy_w_, dx_h_, dx_w_, h_scale, w_scale, dx, interim,
CalResizeBilinearGrad(dy, n_, c_, dy_h_, dy_w_, dx_h_, dx_w_, h_scale, w_scale, half_pixel_centers_, dx, interim,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -73,9 +73,9 @@ class ResizeBilinearGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
if (output_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num;
}
std::vector<size_t> dy_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
std::vector<size_t> x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
std::vector<size_t> dx_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
std::vector<size_t> dy_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 1);
std::vector<size_t> dx_shape = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(dy_shape, kernel_name, "dy") || CHECK_SHAPE_NULL(x_shape, kernel_name, "x") ||
CHECK_SHAPE_NULL(dx_shape, kernel_name, "dx");
if (is_null_input_) {
@ -110,12 +110,14 @@ class ResizeBilinearGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
}
workspace_size_ = (dx_size_ / sizeof(T)) * sizeof(float);
align_corners_ = GetAttr<bool>(kernel_node, "align_corners");
half_pixel_centers_ = GetAttr<bool>(kernel_node, "half_pixel_centers");
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
align_corners_ = false;
half_pixel_centers_ = false;
is_null_input_ = false;
n_ = 0;
c_ = 0;
@ -145,6 +147,7 @@ class ResizeBilinearGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
}
bool align_corners_;
bool half_pixel_centers_;
bool is_null_input_;
int n_;
int c_;

View File

@ -49,6 +49,8 @@
namespace mindspore {
namespace abstract {
PrimShapeDependMap &GetHostDependsMap() {
// Registration directly by the host_depends map will be deprecated and
// should be registered by the REGISTER_HOST_DEPENDS
using ShapeSet = std::set<int64_t>;
static const auto &kOneHot = prim::kPrimOneHot->name();
static const auto &kDropoutGenMask = prim::kPrimDropoutGenMask->name();

View File

@ -89,8 +89,9 @@ class RegisterHostDependsHelper {
~RegisterHostDependsHelper() = default;
};
#define REGISTER_HOST_DEPENDS(name, depends) \
static auto helper_host_depends_##name = abstract::RegisterHostDependsHelper(name, depends);
// Processes such as InferShape need to obtain some inputs value on the host
#define REGISTER_HOST_DEPENDS(name, depends...) \
static auto helper_host_depends_##name = abstract::RegisterHostDependsHelper(name, ##depends);
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_

View File

@ -43,6 +43,8 @@ constexpr auto kCdist = "Cdist";
constexpr auto kCdistGrad = "CdistGrad";
// image
constexpr auto kCropAndResizeGradBoxes = "CropAndResizeGradBoxes";
constexpr auto kResizeBilinearV2 = "ResizeBilinearV2";
constexpr auto kResizeBilinearGrad = "ResizeBilinearGrad";
// Arithmetic
constexpr auto kScalarAdd = "ScalarAdd";
@ -475,6 +477,8 @@ GVAR_DEF(PrimitivePtr, kPrimSegmentSum, std::make_shared<Primitive>(kSegmentSum)
// image
GVAR_DEF(PrimitivePtr, kPrimCropAndResizeGradBoxes, std::make_shared<Primitive>(kCropAndResizeGradBoxes));
GVAR_DEF(PrimitivePtr, kPrimResizeBilinearV2, std::make_shared<Primitive>(kResizeBilinearV2));
GVAR_DEF(PrimitivePtr, kPrimResizeBilinearGrad, std::make_shared<Primitive>(kResizeBilinearGrad));
// NN
GVAR_DEF(PrimitivePtr, kPrimCeLU, std::make_shared<Primitive>("CeLU"));

View File

@ -0,0 +1,98 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/resize_bilinear_grad.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
void ResizeBilinearGrad::set_align_corners(const bool align_corners) {
(void)this->AddAttr(kAlignCorners, api::MakeValue(align_corners));
}
bool ResizeBilinearGrad::get_align_corners() const {
auto value_ptr = GetAttr(kAlignCorners);
return GetValue<bool>(value_ptr);
}
void ResizeBilinearGrad::set_half_pixel_centers(const bool half_pixel_centers) {
(void)this->AddAttr(kHalfPixelCenters, api::MakeValue(half_pixel_centers));
}
bool ResizeBilinearGrad::get_half_pixel_centers() const {
auto value_ptr = GetAttr(kHalfPixelCenters);
return GetValue<bool>(value_ptr);
}
namespace {
abstract::ShapePtr ResizeBilinearGradInferShape(const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x = input_args[kOriginalImageIndex]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto shape_x = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_x);
return shape_x;
}
TypePtr ResizeBilinearGradInferType(const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
MS_EXCEPTION_IF_NULL(input_args[kOriginalImageIndex]);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kOriginalImageIndex);
auto x_type = input_args[kOriginalImageIndex]->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
if (!x_type->isa<TensorType>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', input must be a Tensor, but got: " << x_type->ToString()
<< ".";
}
return x_type;
}
} // namespace
void ResizeBilinearGrad::Init(const bool align_corners, const bool half_pixel_centers) {
this->set_align_corners(align_corners);
this->set_half_pixel_centers(half_pixel_centers);
}
MIND_API_OPERATOR_IMPL(ResizeBilinearGrad, BaseOperator);
AbstractBasePtr ResizeBilinearGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args) {
auto infer_type = ResizeBilinearGradInferType(primitive, input_args);
auto infer_shape = ResizeBilinearGradInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ResizeBilinearGrad, prim::kPrimResizeBilinearGrad, ResizeBilinearGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_GRAD_RESIZE_BILINEAR_GRAD_H_
#define MINDSPORE_CORE_OPS_GRAD_RESIZE_BILINEAR_GRAD_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameResizeBilinearGrad = "ResizeBilinearGrad";
constexpr auto kOriginalImageIndex = 1;
/// \brief Resizes an image to a certain size using the bilinear interpolation.
/// Refer to Python API @ref mindspore.ops.ResizeBilinearGrad for more details.
class MIND_API ResizeBilinearGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ResizeBilinearGrad);
/// \brief Constructor.
ResizeBilinearGrad() : BaseOperator(kNameResizeBilinearGrad) { InitIOName({"grads", "original_image"}, {"y"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ResizeBilinearGrad for the inputs.
void Init(const bool align_corners = false, const bool half_pixel_centers = false);
/// \brief Set align_corners.
void set_align_corners(const bool align_corners);
/// \brief Set half_pixel_centers.
void set_half_pixel_centers(const bool half_pixel_centers);
/// \brief Get align_corners.
///
/// \return align_corners.
bool get_align_corners() const;
/// \brief Get half_pixel_centers.
///
/// \return half_pixel_centers.
bool get_half_pixel_centers() const;
};
abstract::AbstractBasePtr ResizeBilinearGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_GRAD_RESIZE_BILINEAR_GRAD_H_

View File

@ -287,6 +287,7 @@ constexpr auto kIndexing = "indexing";
constexpr auto kModulated = "modulated";
constexpr auto kAdjoint = "adjoint";
constexpr auto kInplaceAlgo = "inplace_algo";
constexpr auto kHalfPixelCenters = "half_pixel_centers";
enum Index : size_t {
kInputIndex0 = 0,

View File

@ -0,0 +1,191 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/resize_bilinear_v2.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
void ResizeBilinearV2::set_align_corners(const bool align_corners) {
(void)this->AddAttr(kAlignCorners, api::MakeValue(align_corners));
}
bool ResizeBilinearV2::get_align_corners() const {
auto value_ptr = GetAttr(kAlignCorners);
return GetValue<bool>(value_ptr);
}
void ResizeBilinearV2::set_half_pixel_centers(const bool half_pixel_centers) {
(void)this->AddAttr(kHalfPixelCenters, api::MakeValue(half_pixel_centers));
}
bool ResizeBilinearV2::get_half_pixel_centers() const {
auto value_ptr = GetAttr(kHalfPixelCenters);
return GetValue<bool>(value_ptr);
}
namespace {
void GetSizeValue(const abstract::AbstractBasePtr &size, std::vector<int64_t> *size_value,
std::vector<int64_t> *min_size, std::vector<int64_t> *max_size) {
MS_EXCEPTION_IF_NULL(size);
auto size_v = size->BuildValue();
MS_EXCEPTION_IF_NULL(size_v);
const int64_t size_size = 2;
auto prim_name = kNameResizeBilinearV2;
if (size->isa<abstract::AbstractTensor>()) {
if (size_v->isa<tensor::Tensor>()) {
*size_value = CheckAndConvertUtils::CheckTensorIntValue("size", size_v, prim_name);
(void)CheckAndConvertUtils::CheckPositiveVector("size", *size_value, prim_name);
} else {
size_value->push_back(-1);
size_value->push_back(-1);
auto min_value = size->cast<abstract::AbstractTensorPtr>()->get_min_value();
auto max_value = size->cast<abstract::AbstractTensorPtr>()->get_max_value();
if (!min_value || !max_value) {
MS_LOG(INFO) << "inputs['size'] min or max value of " << prim_name << " is empty.";
return;
}
*min_size = GetValue<std::vector<int64_t>>(min_value);
*max_size = GetValue<std::vector<int64_t>>(max_value);
if (min_size->size() != size_size || max_size->size() != size_size) {
MS_EXCEPTION(ValueError) << "For " << prim_name
<< ", inputs['size'] min and max value size must be 2, but got min: "
<< min_size->size() << ", max: " << max_size->size() << ".";
}
}
} else if (size->isa<abstract::AbstractTuple>() || size->isa<abstract::AbstractList>()) {
*size_value = CheckAndConvertUtils::CheckIntOrTupleInt("size", size_v, prim_name);
(void)CheckAndConvertUtils::CheckPositiveVector("size", *size_value, prim_name);
} else {
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the "
<< "size"
<< " must be a Tensor or a tuple/list with all Int elements, but got " << size->ToString();
}
}
abstract::ShapePtr ResizeBilinearV2InferShape(const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
auto x_shape = x_shape_ptr->shape();
const int64_t shape_size = 4;
const int64_t size_size = 2;
(void)CheckAndConvertUtils::CheckInteger("the dimension of input_x", SizeToLong(x_shape.size()), kEqual, shape_size,
prim_name);
std::vector<int64_t> size_value;
std::vector<int64_t> min_size;
std::vector<int64_t> max_size;
auto input_num = input_args.size();
constexpr auto kInputNum = 2;
constexpr auto kConstSizeInputNum = 1;
if (input_num == kInputNum) {
auto size = input_args[1];
GetSizeValue(size, &size_value, &min_size, &max_size);
} else if (input_num == kConstSizeInputNum) {
// const size to attr by the convert_const_input_to_attr pass
size_value = GetValue<std::vector<int64_t>>(primitive->GetAttr(kSize));
} else {
MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the number of inputs must be "
<< " 1 or 2, but got " << input_num;
}
(void)CheckAndConvertUtils::CheckInteger("the dimension of size", SizeToLong(size_value.size()), kEqual, size_size,
prim_name);
std::vector<int64_t> output_shape;
std::vector<int64_t> min_shape;
std::vector<int64_t> max_shape;
output_shape.push_back(x_shape[0]);
output_shape.push_back(x_shape[1]);
output_shape.push_back(size_value[0]);
output_shape.push_back(size_value[1]);
// static shape:
if (!x_shape_ptr->IsDynamic() && !(size_value[0] < 0)) {
return std::make_shared<abstract::Shape>(output_shape);
}
// dynamic shape:
auto x_min_shape = x_shape_ptr->min_shape();
auto x_max_shape = x_shape_ptr->max_shape();
// The dynamic shape has no min_shape and max_shape
if ((x_shape_ptr->IsDynamic() && x_min_shape.empty()) || (size_value[0] < 0 && min_size.empty())) {
return std::make_shared<abstract::Shape>(output_shape);
}
// Get min_shape and max_shape of output_shape
if (x_shape_ptr->IsDynamic()) {
(void)CheckAndConvertUtils::CheckInteger("the dimension of min_shape", SizeToLong(x_min_shape.size()), kEqual,
SizeToLong(x_shape.size()), prim_name);
(void)CheckAndConvertUtils::CheckInteger("the dimension of max_shape", SizeToLong(x_max_shape.size()), kEqual,
SizeToLong(x_shape.size()), prim_name);
min_shape.push_back(x_min_shape[0]);
min_shape.push_back(x_min_shape[1]);
max_shape.push_back(x_max_shape[0]);
max_shape.push_back(x_max_shape[1]);
} else {
min_shape.push_back(x_shape[0]);
min_shape.push_back(x_shape[1]);
max_shape.push_back(x_shape[0]);
max_shape.push_back(x_shape[1]);
}
if (size_value[0] < 0) {
(void)CheckAndConvertUtils::CheckInteger("the dimension of min_size", SizeToLong(min_size.size()), kEqual,
size_size, prim_name);
(void)CheckAndConvertUtils::CheckInteger("the dimension of max_size", SizeToLong(max_size.size()), kEqual,
size_size, prim_name);
min_shape.push_back(min_size[0]);
min_shape.push_back(min_size[1]);
max_shape.push_back(max_size[0]);
max_shape.push_back(max_size[1]);
} else {
min_shape.push_back(size_value[0]);
min_shape.push_back(size_value[1]);
max_shape.push_back(size_value[0]);
max_shape.push_back(size_value[1]);
}
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
TypePtr ResizeBilinearV2InferType(const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args) {
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, primitive->name());
}
} // namespace
void ResizeBilinearV2::Init(const bool align_corners, const bool half_pixel_centers) {
this->set_align_corners(align_corners);
this->set_half_pixel_centers(half_pixel_centers);
}
MIND_API_OPERATOR_IMPL(ResizeBilinearV2, BaseOperator);
AbstractBasePtr ResizeBilinearV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args) {
auto infer_type = ResizeBilinearV2InferType(primitive, input_args);
auto infer_shape = ResizeBilinearV2InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ResizeBilinearV2, prim::kPrimResizeBilinearV2, ResizeBilinearV2Infer, nullptr, true);
REGISTER_HOST_DEPENDS(kNameResizeBilinearV2, {1});
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_RESIZE_BILINEAR_V2_H_
#define MINDSPORE_CORE_OPS_RESIZE_BILINEAR_V2_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameResizeBilinearV2 = "ResizeBilinearV2";
/// \brief Resizes an image to a certain size using the bilinear interpolation.
/// Refer to Python API @ref mindspore.ops.ResizeBilinearV2 for more details.
class MIND_API ResizeBilinearV2 : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ResizeBilinearV2);
/// \brief Constructor.
ResizeBilinearV2() : BaseOperator(kNameResizeBilinearV2) { InitIOName({"x", "size"}, {"output"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ResizeBilinearV2 for the inputs.
void Init(const bool align_corners = false, const bool half_pixel_centers = false);
/// \brief Set align_corners.
void set_align_corners(const bool align_corners);
/// \brief Set half_pixel_centers.
void set_half_pixel_centers(const bool half_pixel_centers);
/// \brief Get align_corners.
///
/// \return align_corners.
bool get_align_corners() const;
/// \brief Get half_pixel_centers.
///
/// \return half_pixel_centers.
bool get_half_pixel_centers() const;
};
abstract::AbstractBasePtr ResizeBilinearV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_RESIZE_BILINEAR_V2_H_

View File

@ -151,7 +151,7 @@ std::vector<int64_t> Slice::get_size() const {
return GetValue<std::vector<int64_t>>(value_ptr);
}
REGISTER_HOST_DEPENDS(kNameSlice, (std::set<int64_t>{1, 2}));
REGISTER_HOST_DEPENDS(kNameSlice, {1, 2});
REGISTER_PRIMITIVE_EVAL_IMPL(Slice, prim::kPrimSlice, SliceInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -22,6 +22,7 @@ from .grad_base import bprop_getters
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import nn_ops as NN
def _get_matrix_diag_assist(x_shape, x_dtype):
@ -154,3 +155,15 @@ def get_bprop_ps_roi_pooling(self):
return dx, zeros_like(rois)
return bprop
@bprop_getters.register(NN.ResizeBilinearV2)
def get_bprop_resize_bilinear(self):
"""Grad definition for `ResizeBilinearV2` operation."""
resize_grad = G.ResizeBilinearGrad(self.align_corners, self.half_pixel_centers)
def bprop(x, size, out, dout):
dx = resize_grad(dout, x)
return dx, zeros_like(size)
return bprop

View File

@ -375,6 +375,7 @@ from .ceil import _ceil_tbe
from .ceil_ds import _ceil_ds_tbe
from .log1p import _log1p_tbe
from .resize_bilinear import _resize_bilinear_tbe
from .resize_bilinear_v2 import _resize_bilinear_v2_tbe
from .resize_bilinear_grad import _resize_bilinear_grad_tbe
from .flatten import _flatten_tbe
from .roi_align import _roi_align_tbe

View File

@ -22,8 +22,10 @@ resize_bilinear_grad_op_info = TBERegOp("ResizeBilinearGrad") \
.binfile_name("resize_bilinear_v2_grad.so") \
.compute_cost(10) \
.kernel_name("resize_bilinear_v2_grad") \
.dynamic_compile_static(True) \
.dynamic_shape(True) \
.partial_flag(True) \
.need_check_supported(True) \
.need_check_supported(False) \
.attr("align_corners", "optional", "bool", "all", "false") \
.attr("half_pixel_centers", "optional", "bool", "all", "false") \
.input(0, "grads", False, "required", "all") \

View File

@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResizeBilinear op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
resize_bilinear_v2_op_info = TBERegOp("ResizeBilinearV2") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("resize_bilinear_v2.so") \
.compute_cost(10) \
.kernel_name("resize_bilinear_v2") \
.partial_flag(True) \
.need_check_supported(False) \
.dynamic_compile_static(True) \
.dynamic_shape(True) \
.attr("align_corners", "optional", "bool", "all", "false") \
.attr("half_pixel_centers", "optional", "bool", "all", "false") \
.input(0, "x", False, "required", "all") \
.input(1, "size", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F32_5HD) \
.dtype_format(DataType.F32_5HD, DataType.I32_Default, DataType.F32_5HD) \
.dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F16_5HD) \
.get_op_info()
@op_info_register(resize_bilinear_v2_op_info)
def _resize_bilinear_v2_tbe():
"""ResizeBilinear TBE register"""
return

View File

@ -223,6 +223,7 @@ from .nn_func import (
nll_loss,
cross_entropy,
grid_sample,
resize_bilinear,
)
from .linalg_func import (
svd,

View File

@ -868,6 +868,48 @@ def grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zero
return NN.GridSampler3D(interpolation_mode, padding_mode, align_corners)(input_x, grid)
def resize_bilinear(x, size, align_corners=False, half_pixel_centers=False):
r"""
Resizes an image to a certain size using the bilinear interpolation.
The resizing only affects the lower two dimensions which represent the height and width.
Args:
x (Tensor): Image to be resized. Input images must be a 4-D tensor with shape
:math:`(batch, channels, height, width)`, with data type of float32 or float16.
size (Union[tuple[int], list[int]]): A tuple or list of 2 int elements :math:`(new\_height, new\_width)`,
the new size of the images.
align_corners (bool): If true, rescale input by :math:`(new\_height - 1) / (height - 1)`,
which exactly aligns the 4 corners of images and resized images. If false,
rescale by :math:`new\_height / height`. Default: False.
half_pixel_centers (bool): Whether half pixel center. If set to True, `align_corners` should be False.
Default: False.
Returns:
Tensor, resized image. 4-D with shape :math:`(batch, channels, new\_height, new\_width)`,
with the same data type as input `x`.
Raises:
TypeError: If `align_corners` is not a bool.
TypeError: If `half_pixel_centers` is not a bool.
TypeError: If `align_corners` and `half_pixel_centers` are all True.
ValueError: If `half_pixel_centers` is True and device_target is CPU.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> x = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
>>> output = resize_bilinear(x, (5, 5))
>>> print(output)
[[[[1. 2. 3. 4. 5.]
[1. 2. 3. 4. 5.]
[1. 2. 3. 4. 5.]
[1. 2. 3. 4. 5.]
"""
return NN.ResizeBilinearV2(align_corners, half_pixel_centers)(x, size)
__all__ = [
'adaptive_avgpool2d',
'celu',
@ -882,6 +924,7 @@ __all__ = [
'pad',
'cross_entropy',
'grid_sample',
'resize_bilinear',
'nll_loss'
]
__all__.sort()

View File

@ -1797,7 +1797,7 @@ class GatherDGrad(PrimitiveWithInfer):
return grad_dtype
class ResizeBilinearGrad(PrimitiveWithInfer):
class ResizeBilinearGrad(Primitive):
"""Performs grad of ResizeBilinear operation."""
@prim_attr_register
@ -1808,18 +1808,12 @@ class ResizeBilinearGrad(PrimitiveWithInfer):
self.align_corners = validator.check_value_type("align_corners", align_corners, [bool], self.name)
self.half_pixel_centers = validator.check_value_type("half_pixel_centers",
half_pixel_centers, [bool], self.name)
self.init_prim_io_names(inputs=['grads', 'original_image'], outputs=['y'])
if half_pixel_centers and align_corners:
raise ValueError(f"If half_pixel_centers is True, align_corners must be False, but got {align_corners}")
target = context.get_context("device_target")
if half_pixel_centers and target.lower() != "ascend":
raise ValueError(f"Currently `half_pixel_centers`=True only support in Ascend device_target, "
f"but got {target}")
def infer_shape(self, dout_shape, orig_shape):
return orig_shape
def infer_dtype(self, dout_dtype, orig_type):
return orig_type
if half_pixel_centers and target.lower() == "cpu":
raise ValueError(f"Currently `half_pixel_centers`=True not support in cpu device_target")
class ResizeNearestNeighborGrad(Primitive):

View File

@ -3492,6 +3492,40 @@ class ResizeBilinear(PrimitiveWithInfer):
return input_dtype
class ResizeBilinearV2(Primitive):
r"""
Resizes an image to a certain size using the bilinear interpolation.
Refer to :func:`mindspore.ops.resize_bilinear` for more detail.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> x = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
>>> output = ResizeBilinearV2(x, (5, 5))
>>> print(output)
[[[[1. 2. 3. 4. 5.]
[1. 2. 3. 4. 5.]
[1. 2. 3. 4. 5.]
[1. 2. 3. 4. 5.]
[1. 2. 3. 4. 5.]]]]
"""
@prim_attr_register
def __init__(self, align_corners=False, half_pixel_centers=False):
"""Initialize ResizeBilinear."""
self.init_prim_io_names(inputs=['x', 'size'], outputs=['y'])
self.align_corners = validator.check_value_type("align_corners", align_corners, [bool], self.name)
self.half_pixel_centers = validator.check_value_type("half_pixel_centers",
half_pixel_centers, [bool], self.name)
if half_pixel_centers and align_corners:
raise ValueError(f"If half_pixel_centers is True, align_corners must be False, but got {align_corners}")
target = context.get_context("device_target")
if target.lower() == "cpu":
raise ValueError(f"Currently, for ResizeBilinearV2 CPU device is not supported!")
class OneHot(Primitive):
r"""
Computes a one-hot tensor.

View File

@ -0,0 +1,103 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import context, Tensor
import mindspore.ops as ops
from mindspore import nn
from mindspore import ParameterTuple
class NetResizeBilinear(nn.Cell):
def construct(self, inputs, size):
return ops.resize_bilinear(inputs, size)
def case():
datatype = np.float16
input_tensor = Tensor(np.array(
[[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]).astype(datatype))
resize_nn = NetResizeBilinear()
output = resize_nn(input_tensor, (9, 9))
expected_output = np.array([[[[0.1, 0.1333, 0.1666, 0.2, 0.2333, 0.2666, 0.3, 0.3, 0.3],
[0.2, 0.2333, 0.2666, 0.2998, 0.3333, 0.3667, 0.4, 0.4, 0.4],
[0.2998, 0.3333, 0.3667, 0.4, 0.433, 0.4666, 0.5, 0.5, 0.5],
[0.4, 0.433, 0.4666, 0.5, 0.533, 0.5664, 0.6, 0.6, 0.6],
[0.5, 0.533, 0.5664, 0.5996, 0.6333, 0.6665, 0.6997, 0.6997, 0.6997],
[0.6, 0.6333, 0.6665, 0.6997, 0.733, 0.766, 0.8, 0.7993, 0.8],
[0.7, 0.7334, 0.7666, 0.8, 0.833, 0.866, 0.9, 0.8994, 0.8994],
[0.7, 0.7334, 0.7666, 0.8, 0.833, 0.866, 0.8994, 0.8994, 0.8994],
[0.7, 0.7334, 0.7666, 0.8, 0.8325, 0.866,
0.8994, 0.8994, 0.8994]]]]).astype(datatype)
assert np.allclose(output.asnumpy(), expected_output, 1e-3, 1e-3)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resize_bilinear_ascend():
"""
Feature: Test mindspore.ops.resize_bilinear on ascend.
Description: The size is a input
Expectation: Assert that results are consistent with expect.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
case()
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
case()
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = ops.GradOperation(get_all=True, get_by_list=True, sens_param=True)
self.params = ParameterTuple(net.trainable_params())
def construct(self, *inputs):
gradient_function = self.grad_op(self.net, self.params)
return gradient_function(*inputs)
def case_grad():
x = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float32)
out_grad = np.array([[[[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 1, 1],
[0, 0, 1, 1]]]]).astype(np.float32)
expect = np.array([[[[2.25, 0.75],
[0.75, 4.25]]]]).astype(np.float32)
net = NetResizeBilinear()
grad_net = GradNetWrtX(net)
output = grad_net(Tensor(x), (4, 4), Tensor(out_grad))
assert np.allclose(output[0][0].asnumpy(), expect, 1e-4, 1e-4)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resize_bilinear_grad_ascend():
"""
Feature: Test ResizeBilinearGrad on ascend.
Description: align_corners is False.
Expectation: Assert that results are consistent with expect.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
case_grad()
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
case_grad()

View File

@ -0,0 +1,133 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import context, Tensor
import mindspore.ops as ops
from mindspore import nn
def get_data():
datatype = np.float32
input_data = np.array(
[[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]).astype(datatype)
size = [9, 9]
expected_output = np.array([[[[0.1, 0.1333, 0.1666, 0.2, 0.2333, 0.2666, 0.3, 0.3, 0.3],
[0.2, 0.2333, 0.2666, 0.2998, 0.3333, 0.3667, 0.4, 0.4, 0.4],
[0.2998, 0.3333, 0.3667, 0.4, 0.433, 0.4666, 0.5, 0.5, 0.5],
[0.4, 0.433, 0.4666, 0.5, 0.533, 0.5664, 0.6, 0.6, 0.6],
[0.5, 0.533, 0.5664, 0.5996, 0.6333, 0.6665, 0.6997, 0.6997, 0.6997],
[0.6, 0.6333, 0.6665, 0.6997, 0.733, 0.766, 0.8, 0.7993, 0.8],
[0.7, 0.7334, 0.7666, 0.8, 0.833, 0.866, 0.9, 0.8994, 0.8994],
[0.7, 0.7334, 0.7666, 0.8, 0.833, 0.866, 0.8994, 0.8994, 0.8994],
[0.7, 0.7334, 0.7666, 0.8, 0.8325, 0.866,
0.8994, 0.8994, 0.8994]]]]).astype(datatype)
return input_data, size, expected_output
class NetResizeBilinear(nn.Cell):
def construct(self, inputs, size, indices_input, axis):
unique_input_index, _ = ops.unique(indices_input)
inputs_dyn = ops.gather(inputs, unique_input_index, axis)
return ops.resize_bilinear(inputs_dyn, size)
def case_input_dyn(mode, device_target):
context.set_context(mode=mode, device_target=device_target)
input_data, size, expected = get_data()
resize_nn = NetResizeBilinear()
axis_input = 3
indices_input = np.array([i for i in range(input_data.shape[axis_input])])
output = resize_nn(Tensor(input_data), size, Tensor(indices_input), axis_input)
assert np.allclose(output.asnumpy(), expected, 1e-3, 1e-3)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resize_bilinear_ascend():
"""
Feature: Test resize_bilinear on ascend.
Description: The shape of input is dynamic.
Expectation: Assert that results are consistent with expect.
"""
case_input_dyn(context.GRAPH_MODE, "Ascend")
case_input_dyn(context.PYNATIVE_MODE, "Ascend")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_resize_bilinear_gpu():
"""
Feature: Test resize_bilinear on GPU.
Description: The shape of input is dynamic.
Expectation: Assert that results are consistent with expect.
"""
case_input_dyn(context.GRAPH_MODE, "GPU")
case_input_dyn(context.PYNATIVE_MODE, "GPU")
class NetResizeBilinearSizeDyn(nn.Cell):
def construct(self, x, y, indices_x, indices_y, axis_x, axis_y):
unique_x_index, _ = ops.unique(indices_x)
x_dyn = ops.gather(x, unique_x_index, axis_x)
unique_y_index, _ = ops.unique(indices_y)
y_dyn = ops.gather(y, unique_y_index, axis_y)
size_dyn = ops.TensorShape()(y_dyn)
return ops.resize_bilinear(x_dyn, size_dyn)
def case_input_size_dyn(mode, device_target):
context.set_context(mode=mode, device_target=device_target)
x_data, size, expected = get_data()
y = np.random.rand(*size).astype(np.float32)
resize_nn = NetResizeBilinearSizeDyn()
axis_x = 3
indices_x = np.array([i for i in range(x_data.shape[axis_x])], dtype=np.int32)
axis_y = 1
indices_y = np.array([i for i in range(y.shape[axis_y])], dtype=np.int32)
output = resize_nn(Tensor(x_data), Tensor(y), Tensor(indices_x), Tensor(indices_y), axis_x, axis_y)
assert np.allclose(output.asnumpy(), expected, 1e-3, 1e-3)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resize_bilinear_size_dyn_ascend():
"""
Feature: Test resize_bilinear on Ascend.
Description: The shape of input and size is dynamic.
Expectation: Assert that results are consistent with expect.
"""
case_input_size_dyn(context.GRAPH_MODE, "Ascend")
case_input_size_dyn(context.PYNATIVE_MODE, "Ascend")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_resize_bilinear_size_dyn_gpu():
"""
Feature: Test resize_bilinear on GPU.
Description: The shape of input and size is dynamic.
Expectation: Assert that results are consistent with expect.
"""
case_input_size_dyn(context.GRAPH_MODE, "GPU")
case_input_size_dyn(context.PYNATIVE_MODE, "GPU")

View File

@ -0,0 +1,83 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
class ResizeBilinearGradNet(nn.Cell):
def __init__(self, align_corners=False):
super(ResizeBilinearGradNet, self).__init__()
self.rb1 = G.ResizeBilinearGrad(align_corners=align_corners)
def construct(self, dy, size, indices_dy, indices_size, axis):
unique_dy_index, _ = ops.unique(indices_dy)
unique_size_index, _ = ops.unique(indices_size)
dy_ = ops.gather(dy, unique_dy_index, axis)
size_ = ops.gather(size, unique_size_index, axis)
return self.rb1(dy_, size_)
def dyn_case():
dy = np.array([[[[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 1, 1],
[0, 0, 1, 1]]]]).astype(np.float32)
x = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float32)
expect = np.array([[[[2.25, 0.75],
[0.75, 4.25]]]]).astype(np.float32)
net = ResizeBilinearGradNet()
axis = 3
indices_dy = np.array([i for i in range(dy.shape[axis])])
indices_x = np.array([i for i in range(x.shape[axis])])
output = net(Tensor(dy), Tensor(x), Tensor(indices_dy), Tensor(indices_x), axis)
assert np.all(output.asnumpy() == expect)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resize_bilinear_grad_dyn_ascend():
"""
Feature: Test ResizeBilinearGrad on Ascend.
Description: The shape of inputs is dynamic.
Expectation: Assert that results are consistent with expect.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
dyn_case()
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
dyn_case()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_resize_bilinear_grad_dyn_gpu():
"""
Feature: Test ResizeBilinearGrad on GPU.
Description: The shape of inputs is dynamic.
Expectation: Assert that results are consistent with expect.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dyn_case()
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dyn_case()

View File

@ -23,9 +23,9 @@ from mindspore.ops.operations import _grad_ops as G
class ResizeBilinearGradNet(nn.Cell):
def __init__(self, align_corners=False):
def __init__(self, align_corners=False, half_pixel_centers=False):
super(ResizeBilinearGradNet, self).__init__()
self.rb1 = G.ResizeBilinearGrad(align_corners=align_corners)
self.rb1 = G.ResizeBilinearGrad(align_corners=align_corners, half_pixel_centers=half_pixel_centers)
def construct(self, dy, size):
return self.rb1(dy, size)
@ -91,3 +91,41 @@ def test_resize_bilinear_grad():
net = ResizeBilinearGradNet()
output = net(Tensor(dy), Tensor(x))
assert np.all(output.asnumpy() == expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_resize_bilinear_grad_half_pixel_centers():
"""
Feature: Test ResizeBilinearGrad on GPU.
Description: The half_pixel_centers is True.
Expectation: Assert that results are consistent with expect.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float16)
x = np.array([[[[1.1, 2.2, 3.2, 2.5],
[3.3, 4.4, 5.7, 8.1],
[3.3, 4.4, 5.7, 8.1],
[3.3, 4.4, 5.7, 8.1]]]]).astype(np.float16)
expect = np.array([[[[0.25, 0.25, 0.5, 0.5],
[0.25, 0.25, 0.5, 0.5],
[0.75, 0.75, 1.0, 1.0],
[0.75, 0.75, 1.0, 1.0]]]], dtype=np.float16)
net = ResizeBilinearGradNet(half_pixel_centers=True)
output = net(Tensor(dy), Tensor(x))
assert np.all(output.asnumpy() == expect)
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32)
x = np.array([[[[1.1, 2.2, 3.2, 2.5],
[3.3, 4.4, 5.7, 8.1],
[3.3, 4.4, 5.7, 8.1],
[3.3, 4.4, 5.7, 8.1]]]]).astype(np.float32)
expect = np.array([[[[0.25, 0.25, 0.5, 0.5],
[0.25, 0.25, 0.5, 0.5],
[0.75, 0.75, 1.0, 1.0],
[0.75, 0.75, 1.0, 1.0]]]], dtype=np.float32)
net = ResizeBilinearGradNet(half_pixel_centers=True)
output = net(Tensor(dy), Tensor(x))
assert np.all(output.asnumpy() == expect)

View File

@ -16,6 +16,7 @@ import numpy as np
import pytest
from mindspore import context, Tensor
import mindspore.ops as ops
from mindspore.ops import operations as P
from mindspore import nn
@ -570,3 +571,30 @@ def test_resize_nn_grayscale_align_corners_float(datatype=np.float32):
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
assert np.all(abs(diff_align) < error)
class NetResizeBilinearFunc(nn.Cell):
def construct(self, inputs, size, align_corner=False, half_pixel_centers=False):
return ops.resize_bilinear(inputs, size, align_corner, half_pixel_centers)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_resize_nn_func_half_pixel_centers(datatype=np.float32):
"""
Feature: Test resize_bilinear on GPU.
Description: The half_pixel_centers is True.
Expectation: Assert that results are consistent with expect.
"""
input_tensor = Tensor(
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(datatype))
resize_nn_func = NetResizeBilinearFunc()
output = resize_nn_func(input_tensor, (3, 7), align_corner=False, half_pixel_centers=True)
expected_output = np.array([[[[0.1, 0.13571429, 0.19285715, 0.25, 0.30714288,
0.36428574, 0.4],
[0.3, 0.3357143, 0.39285716, 0.45, 0.5071429,
0.56428576, 0.6],
[0.5, 0.5357143, 0.5928572, 0.65, 0.7071429,
0.76428574, 0.8]]]], dtype=datatype)
assert np.allclose(output.asnumpy(), expected_output)