From af70cc3b2ce3d27c53f4612e76eaf678307b444c Mon Sep 17 00:00:00 2001 From: hanhuifeng2020 Date: Thu, 12 May 2022 19:51:20 +0800 Subject: [PATCH] support ResizeBilinear and ResizeBilinearGrad op --- .../device/ascend/kernel/tbe/tbe_adapter.h | 1 + .../kernel/tbe/tiling/op_tiling_adapter.cc | 1 + .../cuda_ops/resize_bilinear_impl.cu | 189 +++++++++++++---- .../cuda_ops/resize_bilinear_impl.cuh | 11 +- .../kernel/nn/resize_bilinear_gpu_kernel.cc | 16 ++ .../kernel/nn/resize_bilinear_gpu_kernel.h | 17 +- .../nn/resize_bilinear_grad_gpu_kernel.h | 11 +- .../core/abstract/ops/primitive_infer_map.cc | 2 + .../core/abstract/ops/primitive_infer_map.h | 5 +- mindspore/core/ops/core_ops.h | 4 + .../core/ops/grad/resize_bilinear_grad.cc | 98 +++++++++ .../core/ops/grad/resize_bilinear_grad.h | 58 ++++++ mindspore/core/ops/op_name.h | 1 + mindspore/core/ops/resize_bilinear_v2.cc | 191 ++++++++++++++++++ mindspore/core/ops/resize_bilinear_v2.h | 57 ++++++ mindspore/core/ops/slice.cc | 2 +- .../mindspore/ops/_grad/grad_inner_ops.py | 13 ++ .../mindspore/ops/_op_impl/tbe/__init__.py | 1 + .../ops/_op_impl/tbe/resize_bilinear_grad.py | 4 +- .../ops/_op_impl/tbe/resize_bilinear_v2.py | 43 ++++ .../python/mindspore/ops/function/__init__.py | 1 + .../python/mindspore/ops/function/nn_func.py | 43 ++++ .../mindspore/ops/operations/_grad_ops.py | 14 +- .../python/mindspore/ops/operations/nn_ops.py | 34 ++++ tests/st/ops/ascend/test_resize_bilinear.py | 103 ++++++++++ .../dynamic_shape/test_resize_bilinear_dyn.py | 133 ++++++++++++ .../test_resize_bilinear_grad_dyn.py | 83 ++++++++ .../ops/gpu/test_resize_bilinear_grad_op.py | 42 +++- tests/st/ops/gpu/test_resize_bilinear_op.py | 28 +++ 29 files changed, 1138 insertions(+), 68 deletions(-) create mode 100644 mindspore/core/ops/grad/resize_bilinear_grad.cc create mode 100644 mindspore/core/ops/grad/resize_bilinear_grad.h create mode 100644 mindspore/core/ops/resize_bilinear_v2.cc create mode 100644 mindspore/core/ops/resize_bilinear_v2.h create mode 100644 mindspore/python/mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py create mode 100644 tests/st/ops/ascend/test_resize_bilinear.py create mode 100644 tests/st/ops/dynamic_shape/test_resize_bilinear_dyn.py create mode 100644 tests/st/ops/dynamic_shape/test_resize_bilinear_grad_dyn.py diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_adapter.h b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_adapter.h index 9422a30507b..74a9b403109 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_adapter.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_adapter.h @@ -54,6 +54,7 @@ const std::map opTypeAdapter = {{"ReLUV2", "ReluV2"}, {"TransposeNOD", "Transpose"}, {"ParallelResizeBilinear", "SyncResizeBilinearV2"}, {"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"}, + {"ResizeBilinearGrad", "ResizeBilinearV2Grad"}, {"Split", "SplitD"}, {"HSwish", "HardSwish"}, {"HSwishGrad", "HardSwishGrad"}, diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tiling/op_tiling_adapter.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tiling/op_tiling_adapter.cc index 6e6c6904118..3b71ebcc67b 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tiling/op_tiling_adapter.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tiling/op_tiling_adapter.cc @@ -59,6 +59,7 @@ std::string OpTilingCalculateAdapter::GetRealOpType(const std::string &op_type) {"DynamicResizeNearestNeighbor", "ResizeNearestNeighborV2"}, {"ParallelResizeBilinear", "SyncResizeBilinearV2"}, 
{"ParallelResizeBilinearGrad", "SyncResizeBilinearV2Grad"}, + {"ResizeBilinearGrad", "ResizeBilinearV2Grad"}, {"HSwish", "HardSwish"}, {"HSwishGrad", "HardSwishGrad"}, {"CeLU", "CeluV2"}, diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/resize_bilinear_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/resize_bilinear_impl.cu index 131d6d15e8b..5a4c1c04d0d 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/resize_bilinear_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/resize_bilinear_impl.cu @@ -18,8 +18,8 @@ #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" template __global__ void ResizeBilinear(const T *input, const int n, const int c, const int input_h, const int input_w, - const int output_h, const int output_w, const int nchw, const int chw, const int hw, const float h_scale, - const float w_scale, T *output) { + const int output_h, const int output_w, const int nchw, const int chw, const int hw, + const float h_scale, const float w_scale, T *output) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) { const int posn = pos / chw; const int posc = pos / hw % c; @@ -27,29 +27,59 @@ __global__ void ResizeBilinear(const T *input, const int n, const int c, const i const int posw = pos % output_w; const float posw_scaled = w_scale * posw; const float posh_scaled = h_scale * posh; - const int w_low = max(static_cast(floorf(posw_scaled)), 0); // NOLINT - const int w_high = min(static_cast(ceilf(posw_scaled)), input_w - 1); // NOLINT - const int h_low = max(static_cast(floorf(posh_scaled)), 0); // NOLINT - const int h_high = min(static_cast(ceilf(posh_scaled)), input_h - 1); // NOLINT + const int w_low = max(static_cast(floorf(posw_scaled)), 0); // NOLINT + const int w_high = min(static_cast(ceilf(posw_scaled)), input_w - 1); // NOLINT + const int h_low = max(static_cast(floorf(posh_scaled)), 0); // NOLINT + const int h_high = min(static_cast(ceilf(posh_scaled)), input_h - 1); // NOLINT const float w_alpha = posw_scaled - w_low; const float w_beta = 1.0f - w_alpha; const float h_alpha = posh_scaled - h_low; const float h_beta = 1.0f - h_alpha; - const int input_start = input_h * input_w * (posn * c + posc); + const int input_start = input_h * input_w * (posn * c + posc); const T p1 = input[input_start + (h_low * input_w) + w_low]; const T p2 = input[input_start + (h_low * input_w) + w_high]; const T p3 = input[input_start + (h_high * input_w) + w_low]; const T p4 = input[input_start + (h_high * input_w) + w_high]; - output[pos] = (p1 * static_cast(h_beta * w_beta)) + (p2 * static_cast(h_beta * w_alpha)) - + (p3 * static_cast(h_alpha * w_beta)) + (p4 * static_cast(h_alpha * w_alpha)); + output[pos] = (p1 * static_cast(h_beta * w_beta)) + (p2 * static_cast(h_beta * w_alpha)) + + (p3 * static_cast(h_alpha * w_beta)) + (p4 * static_cast(h_alpha * w_alpha)); + } + return; +} + +template +__global__ void ResizeBilinear_HPC(const T *input, const int n, const int c, const int input_h, const int input_w, + const int output_h, const int output_w, const int nchw, const int chw, const int hw, + const float h_scale, const float w_scale, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) { + const int posn = pos / chw; + const int posc = pos / hw % c; + const int posh = pos / output_w % output_h; + const int posw = pos % output_w; + const float posw_scaled = (static_cast(posw) + 0.5f) * 
w_scale - 0.5f; + const float posh_scaled = (static_cast(posh) + 0.5f) * h_scale - 0.5f; + const int w_low = max(static_cast(floorf(posw_scaled)), 0); // NOLINT + const int w_high = min(static_cast(ceilf(posw_scaled)), input_w - 1); // NOLINT + const int h_low = max(static_cast(floorf(posh_scaled)), 0); // NOLINT + const int h_high = min(static_cast(ceilf(posh_scaled)), input_h - 1); // NOLINT + const float w_alpha = posw_scaled - floorf(posw_scaled); + const float w_beta = 1.0f - w_alpha; + const float h_alpha = posh_scaled - floorf(posh_scaled); + const float h_beta = 1.0f - h_alpha; + const int input_start = input_h * input_w * (posn * c + posc); + const T p1 = input[input_start + (h_low * input_w) + w_low]; + const T p2 = input[input_start + (h_low * input_w) + w_high]; + const T p3 = input[input_start + (h_high * input_w) + w_low]; + const T p4 = input[input_start + (h_high * input_w) + w_high]; + output[pos] = (p1 * static_cast(h_beta * w_beta)) + (p2 * static_cast(h_beta * w_alpha)) + + (p3 * static_cast(h_alpha * w_beta)) + (p4 * static_cast(h_alpha * w_alpha)); } return; } // fp16 path __global__ void ResizeBilinearGrad(const half *input, const int n, const int c, const int input_h, const int input_w, - const int output_h, const int output_w, const int nchw, const int chw, const int hw, const float h_scale, - const float w_scale, half *output, float *interim) { + const int output_h, const int output_w, const int nchw, const int chw, const int hw, + const float h_scale, const float w_scale, half *output, float *interim) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) { const int posn = pos / chw; const int posc = pos / hw % c; @@ -57,10 +87,10 @@ __global__ void ResizeBilinearGrad(const half *input, const int n, const int c, const int posw = pos % input_w; const float posw_scaled = w_scale * posw; const float posh_scaled = h_scale * posh; - const int w_low = max(static_cast(floorf(posw_scaled)), 0); // NOLINT - const int w_high = min(static_cast(ceilf(posw_scaled)), output_w - 1); // NOLINT - const int h_low = max(static_cast(floorf(posh_scaled)), 0); // NOLINT - const int h_high = min(static_cast(ceilf(posh_scaled)), output_h - 1); // NOLINT + const int w_low = max(static_cast(floorf(posw_scaled)), 0); // NOLINT + const int w_high = min(static_cast(ceilf(posw_scaled)), output_w - 1); // NOLINT + const int h_low = max(static_cast(floorf(posh_scaled)), 0); // NOLINT + const int h_high = min(static_cast(ceilf(posh_scaled)), output_h - 1); // NOLINT const float w_alpha = posw_scaled - w_low; const float w_beta = 1.0f - w_alpha; const float h_alpha = posh_scaled - h_low; @@ -70,7 +100,7 @@ __global__ void ResizeBilinearGrad(const half *input, const int n, const int c, const float dp2 = h_beta * w_alpha * grad; const float dp3 = h_alpha * w_beta * grad; const float dp4 = h_alpha * w_alpha * grad; - const int output_start = output_h * output_w * (posn * c + posc); + const int output_start = output_h * output_w * (posn * c + posc); atomicAdd(&interim[output_start + (h_low * output_w) + w_low], dp1); atomicAdd(&interim[output_start + (h_low * output_w) + w_high], dp2); atomicAdd(&interim[output_start + (h_high * output_w) + w_low], dp3); @@ -81,8 +111,8 @@ __global__ void ResizeBilinearGrad(const half *input, const int n, const int c, // fp32 path __global__ void ResizeBilinearGrad(const float *input, const int n, const int c, const int input_h, const int input_w, - const int output_h, const int output_w, const int nchw, const int chw, 
const int hw, const float h_scale, - const float w_scale, float *output, float *interim) { + const int output_h, const int output_w, const int nchw, const int chw, const int hw, + const float h_scale, const float w_scale, float *output, float *interim) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) { const int posn = pos / chw; const int posc = pos / hw % c; @@ -90,10 +120,10 @@ __global__ void ResizeBilinearGrad(const float *input, const int n, const int c, const int posw = pos % input_w; const float posw_scaled = w_scale * posw; const float posh_scaled = h_scale * posh; - const int w_low = max(static_cast(floorf(posw_scaled)), 0); // NOLINT - const int w_high = min(static_cast(ceilf(posw_scaled)), output_w - 1); // NOLINT - const int h_low = max(static_cast(floorf(posh_scaled)), 0); // NOLINT - const int h_high = min(static_cast(ceilf(posh_scaled)), output_h - 1); // NOLINT + const int w_low = max(static_cast(floorf(posw_scaled)), 0); // NOLINT + const int w_high = min(static_cast(ceilf(posw_scaled)), output_w - 1); // NOLINT + const int h_low = max(static_cast(floorf(posh_scaled)), 0); // NOLINT + const int h_high = min(static_cast(ceilf(posh_scaled)), output_h - 1); // NOLINT const float w_alpha = posw_scaled - w_low; const float w_beta = 1.0f - w_alpha; const float h_alpha = posh_scaled - h_low; @@ -103,7 +133,75 @@ __global__ void ResizeBilinearGrad(const float *input, const int n, const int c, const float dp2 = h_beta * w_alpha * grad; const float dp3 = h_alpha * w_beta * grad; const float dp4 = h_alpha * w_alpha * grad; - const int output_start = output_h * output_w * (posn * c + posc); + const int output_start = output_h * output_w * (posn * c + posc); + atomicAdd(&output[output_start + (h_low * output_w) + w_low], dp1); + atomicAdd(&output[output_start + (h_low * output_w) + w_high], dp2); + atomicAdd(&output[output_start + (h_high * output_w) + w_low], dp3); + atomicAdd(&output[output_start + (h_high * output_w) + w_high], dp4); + } + return; +} + +// fp16 path +__global__ void ResizeBilinearGrad_HPC(const half *input, const int n, const int c, const int input_h, + const int input_w, const int output_h, const int output_w, const int nchw, + const int chw, const int hw, const float h_scale, const float w_scale, + half *output, float *interim) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) { + const int posn = pos / chw; + const int posc = pos / hw % c; + const int posh = pos / input_w % input_h; + const int posw = pos % input_w; + const float posw_scaled = (static_cast(posw) + 0.5f) * w_scale - 0.5f; + const float posh_scaled = (static_cast(posh) + 0.5f) * h_scale - 0.5f; + const int w_low = max(static_cast(floorf(posw_scaled)), 0); // NOLINT + const int w_high = min(static_cast(ceilf(posw_scaled)), output_w - 1); // NOLINT + const int h_low = max(static_cast(floorf(posh_scaled)), 0); // NOLINT + const int h_high = min(static_cast(ceilf(posh_scaled)), output_h - 1); // NOLINT + const float w_alpha = posw_scaled - floorf(posw_scaled); + const float w_beta = 1.0f - w_alpha; + const float h_alpha = posh_scaled - floorf(posh_scaled); + const float h_beta = 1.0f - h_alpha; + const float grad = static_cast(input[pos]); + const float dp1 = h_beta * w_beta * grad; + const float dp2 = h_beta * w_alpha * grad; + const float dp3 = h_alpha * w_beta * grad; + const float dp4 = h_alpha * w_alpha * grad; + const int output_start = output_h * output_w * (posn * c + posc); + 
atomicAdd(&interim[output_start + (h_low * output_w) + w_low], dp1); + atomicAdd(&interim[output_start + (h_low * output_w) + w_high], dp2); + atomicAdd(&interim[output_start + (h_high * output_w) + w_low], dp3); + atomicAdd(&interim[output_start + (h_high * output_w) + w_high], dp4); + } + return; +} + +// fp32 path +__global__ void ResizeBilinearGrad_HPC(const float *input, const int n, const int c, const int input_h, + const int input_w, const int output_h, const int output_w, const int nchw, + const int chw, const int hw, const float h_scale, const float w_scale, + float *output, float *interim) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) { + const int posn = pos / chw; + const int posc = pos / hw % c; + const int posh = pos / input_w % input_h; + const int posw = pos % input_w; + const float posw_scaled = (static_cast(posw) + 0.5f) * w_scale - 0.5f; + const float posh_scaled = (static_cast(posh) + 0.5f) * h_scale - 0.5f; + const int w_low = max(static_cast(floorf(posw_scaled)), 0); // NOLINT + const int w_high = min(static_cast(ceilf(posw_scaled)), output_w - 1); // NOLINT + const int h_low = max(static_cast(floorf(posh_scaled)), 0); // NOLINT + const int h_high = min(static_cast(ceilf(posh_scaled)), output_h - 1); // NOLINT + const float w_alpha = posw_scaled - floorf(posw_scaled); + const float w_beta = 1.0f - w_alpha; + const float h_alpha = posh_scaled - floorf(posh_scaled); + const float h_beta = 1.0f - h_alpha; + const float grad = input[pos]; + const float dp1 = h_beta * w_beta * grad; + const float dp2 = h_beta * w_alpha * grad; + const float dp3 = h_alpha * w_beta * grad; + const float dp4 = h_alpha * w_alpha * grad; + const int output_start = output_h * output_w * (posn * c + posc); atomicAdd(&output[output_start + (h_low * output_w) + w_low], dp1); atomicAdd(&output[output_start + (h_low * output_w) + w_high], dp2); atomicAdd(&output[output_start + (h_high * output_w) + w_low], dp3); @@ -121,45 +219,62 @@ __global__ void ResizeBilinearGradPost(const int nchw, half *output, float *inte template void CalResizeBilinear(const T *input, const int n, const int c, const int input_h, const int input_w, - const int output_h, const int output_w, const float h_scale, const float w_scale, T *output, - cudaStream_t cuda_stream) { + const int output_h, const int output_w, const float h_scale, const float w_scale, + const bool half_pixel_centers, T *output, cudaStream_t cuda_stream) { const int nchw = n * c * output_h * output_w; const int chw = c * output_h * output_w; const int hw = output_h * output_w; - ResizeBilinear<<>>(input, n, c, input_h, input_w, output_h, - output_w, nchw, chw, hw, h_scale, w_scale, output); + if (half_pixel_centers) { + ResizeBilinear_HPC<<>>( + input, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output); + } else { + ResizeBilinear<<>>(input, n, c, input_h, input_w, output_h, output_w, + nchw, chw, hw, h_scale, w_scale, output); + } return; } void CalResizeBilinearGrad(const half *input, const int n, const int c, const int input_h, const int input_w, - const int output_h, const int output_w, const float h_scale, const float w_scale, half *output, float *interim, - cudaStream_t cuda_stream) { + const int output_h, const int output_w, const float h_scale, const float w_scale, + const bool half_pixel_centers, half *output, float *interim, cudaStream_t cuda_stream) { const int hw = input_h * input_w; const int chw = c * hw; const int nchw = n * chw; const int output_num = 
n * c * output_h * output_w; - ResizeBilinearGrad<<>>(input, n, c, input_h, input_w, output_h, - output_w, nchw, chw, hw, h_scale, w_scale, output, interim); + if (half_pixel_centers) { + ResizeBilinearGrad_HPC<<>>( + input, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output, interim); + } else { + ResizeBilinearGrad<<>>( + input, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output, interim); + } ResizeBilinearGradPost<<>>(output_num, output, interim); return; } void CalResizeBilinearGrad(const float *input, const int n, const int c, const int input_h, const int input_w, - const int output_h, const int output_w, const float h_scale, const float w_scale, float *output, float *interim, - cudaStream_t cuda_stream) { + const int output_h, const int output_w, const float h_scale, const float w_scale, + const bool half_pixel_centers, float *output, float *interim, cudaStream_t cuda_stream) { const int hw = input_h * input_w; const int chw = c * hw; const int nchw = n * chw; - ResizeBilinearGrad<<>>(input, n, c, input_h, input_w, output_h, - output_w, nchw, chw, hw, h_scale, w_scale, output, interim); + if (half_pixel_centers) { + ResizeBilinearGrad_HPC<<>>( + input, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output, interim); + } else { + ResizeBilinearGrad<<>>( + input, n, c, input_h, input_w, output_h, output_w, nchw, chw, hw, h_scale, w_scale, output, interim); + } return; } template CUDA_LIB_EXPORT void CalResizeBilinear(const float *input, const int n, const int c, const int input_h, const int input_w, const int output_h, const int output_w, - const float h_scale, const float w_scale, float *output, + const float h_scale, const float w_scale, + const bool half_pixel_centers, float *output, cudaStream_t cuda_stream); template CUDA_LIB_EXPORT void CalResizeBilinear(const half *input, const int n, const int c, const int input_h, const int input_w, const int output_h, const int output_w, - const float h_scale, const float w_scale, half *output, + const float h_scale, const float w_scale, + const bool half_pixel_centers, half *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/resize_bilinear_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/resize_bilinear_impl.cuh index 710b5d03886..def9b04bdf7 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/resize_bilinear_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/resize_bilinear_impl.cuh @@ -21,13 +21,14 @@ template CUDA_LIB_EXPORT void CalResizeBilinear(const T *input, const int n_, const int c_, const int input_h_, const int input_w_, const int output_h_, const int output_w_, - const float h_scale, const float w_scale, T *output, cudaStream_t cuda_stream); + const float h_scale, const float w_scale, const bool half_pixel_centers, + T *output, cudaStream_t cuda_stream); CUDA_LIB_EXPORT void CalResizeBilinearGrad(const half *input, const int n_, const int c_, const int input_h_, const int input_w_, const int output_h_, const int output_w_, - const float h_scale, const float w_scale, half *output, float *interim, - cudaStream_t cuda_stream); + const float h_scale, const float w_scale, const bool half_pixel_centers, + half *output, float *interim, cudaStream_t cuda_stream); CUDA_LIB_EXPORT void CalResizeBilinearGrad(const float *input, const int n_, const int c_, const int input_h_, const int input_w_, const int output_h_, const 
int output_w_, - const float h_scale, const float w_scale, float *output, float *interim, - cudaStream_t cuda_stream); + const float h_scale, const float w_scale, const bool half_pixel_centers, + float *output, float *interim, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_RESIZE_BILINEAR_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_gpu_kernel.cc index 4e1248fdbf0..85c0b812a5b 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_gpu_kernel.cc @@ -22,5 +22,21 @@ MS_REG_GPU_KERNEL_ONE(ResizeBilinear, KernelAttr().AddInputAttr(kNumberTypeFloat ResizeBilinearGpuKernelMod, float) MS_REG_GPU_KERNEL_ONE(ResizeBilinear, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), ResizeBilinearGpuKernelMod, half) +MS_REG_GPU_KERNEL_ONE( + ResizeBilinearV2, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + ResizeBilinearGpuKernelMod, float) +MS_REG_GPU_KERNEL_ONE( + ResizeBilinearV2, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + ResizeBilinearGpuKernelMod, half) +MS_REG_GPU_KERNEL_ONE( + ResizeBilinearV2, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + ResizeBilinearGpuKernelMod, float) +MS_REG_GPU_KERNEL_ONE( + ResizeBilinearV2, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + ResizeBilinearGpuKernelMod, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_gpu_kernel.h index af9f735cb04..7cf7a4ac9c5 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_gpu_kernel.h @@ -39,8 +39,8 @@ class ResizeBilinearGpuKernelMod : public DeprecatedNativeGpuKernelMod { T *output = GetDeviceAddress(outputs, 0); float h_scale = Scaling(input_h_, output_h_, align_corners_); float w_scale = Scaling(input_w_, output_w_, align_corners_); - CalResizeBilinear(input, n_, c_, input_h_, input_w_, output_h_, output_w_, h_scale, w_scale, output, - reinterpret_cast(stream_ptr)); + CalResizeBilinear(input, n_, c_, input_h_, input_w_, output_h_, output_w_, h_scale, w_scale, half_pixel_centers_, + output, reinterpret_cast(stream_ptr)); return true; } @@ -48,15 +48,17 @@ class ResizeBilinearGpuKernelMod : public DeprecatedNativeGpuKernelMod { auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node); size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); kernel_node_ = kernel_node; - if (input_num != 1) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 1, but got " << input_num; + constexpr size_t kDynamicSizeInputNum = 2; + if (input_num != 1 && input_num != kDynamicSizeInputNum) { + MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 1 or " << kDynamicSizeInputNum + << ", but got " << input_num; } size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node); if (output_num != 1) { MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must 
be 1, but got " << output_num; } - std::vector input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - std::vector output_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, 0); + std::vector input_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0); + std::vector output_shape = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0); is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name, "input") || CHECK_SHAPE_NULL(output_shape, kernel_name, "output"); if (is_null_input_) { @@ -83,12 +85,14 @@ class ResizeBilinearGpuKernelMod : public DeprecatedNativeGpuKernelMod { output_size_ *= x; } align_corners_ = GetAttr(kernel_node, "align_corners"); + half_pixel_centers_ = GetAttr(kernel_node, "half_pixel_centers"); InitSizeLists(); return true; } void ResetResource() noexcept override { align_corners_ = false; + half_pixel_centers_ = false; is_null_input_ = false; n_ = 0; c_ = 0; @@ -117,6 +121,7 @@ class ResizeBilinearGpuKernelMod : public DeprecatedNativeGpuKernelMod { } bool align_corners_; + bool half_pixel_centers_; bool is_null_input_; int n_; int c_; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_grad_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_grad_gpu_kernel.h index 8189455a788..9e9fb58ed7a 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_grad_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/resize_bilinear_grad_gpu_kernel.h @@ -57,7 +57,7 @@ class ResizeBilinearGradGpuKernelMod : public DeprecatedNativeGpuKernelMod { CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMemsetAsync(interim, 0, workspace_size_, reinterpret_cast(stream_ptr)), "cudaMemsetAsync dx_interim failed"); - CalResizeBilinearGrad(dy, n_, c_, dy_h_, dy_w_, dx_h_, dx_w_, h_scale, w_scale, dx, interim, + CalResizeBilinearGrad(dy, n_, c_, dy_h_, dy_w_, dx_h_, dx_w_, h_scale, w_scale, half_pixel_centers_, dx, interim, reinterpret_cast(stream_ptr)); return true; } @@ -73,9 +73,9 @@ class ResizeBilinearGradGpuKernelMod : public DeprecatedNativeGpuKernelMod { if (output_num != 1) { MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num; } - std::vector dy_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - std::vector x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - std::vector dx_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, 0); + std::vector dy_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0); + std::vector x_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 1); + std::vector dx_shape = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0); is_null_input_ = CHECK_SHAPE_NULL(dy_shape, kernel_name, "dy") || CHECK_SHAPE_NULL(x_shape, kernel_name, "x") || CHECK_SHAPE_NULL(dx_shape, kernel_name, "dx"); if (is_null_input_) { @@ -110,12 +110,14 @@ class ResizeBilinearGradGpuKernelMod : public DeprecatedNativeGpuKernelMod { } workspace_size_ = (dx_size_ / sizeof(T)) * sizeof(float); align_corners_ = GetAttr(kernel_node, "align_corners"); + half_pixel_centers_ = GetAttr(kernel_node, "half_pixel_centers"); InitSizeLists(); return true; } void ResetResource() noexcept override { align_corners_ = false; + half_pixel_centers_ = false; is_null_input_ = false; n_ = 0; c_ = 0; @@ -145,6 +147,7 @@ class ResizeBilinearGradGpuKernelMod : public DeprecatedNativeGpuKernelMod { } bool align_corners_; + bool half_pixel_centers_; bool is_null_input_; int n_; int c_; diff --git 
a/mindspore/core/abstract/ops/primitive_infer_map.cc b/mindspore/core/abstract/ops/primitive_infer_map.cc index 101375c330d..a52d691ac70 100644 --- a/mindspore/core/abstract/ops/primitive_infer_map.cc +++ b/mindspore/core/abstract/ops/primitive_infer_map.cc @@ -49,6 +49,8 @@ namespace mindspore { namespace abstract { PrimShapeDependMap &GetHostDependsMap() { + // Registration directly by the host_depends map will be deprecated and + // should be registered by the REGISTER_HOST_DEPENDS using ShapeSet = std::set; static const auto &kOneHot = prim::kPrimOneHot->name(); static const auto &kDropoutGenMask = prim::kPrimDropoutGenMask->name(); diff --git a/mindspore/core/abstract/ops/primitive_infer_map.h b/mindspore/core/abstract/ops/primitive_infer_map.h index b282d52b73c..f16fff84b30 100644 --- a/mindspore/core/abstract/ops/primitive_infer_map.h +++ b/mindspore/core/abstract/ops/primitive_infer_map.h @@ -89,8 +89,9 @@ class RegisterHostDependsHelper { ~RegisterHostDependsHelper() = default; }; -#define REGISTER_HOST_DEPENDS(name, depends) \ - static auto helper_host_depends_##name = abstract::RegisterHostDependsHelper(name, depends); +// Processes such as InferShape need to obtain some inputs value on the host +#define REGISTER_HOST_DEPENDS(name, depends...) \ + static auto helper_host_depends_##name = abstract::RegisterHostDependsHelper(name, ##depends); } // namespace abstract } // namespace mindspore #endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index dbfc6d9ff3d..0f0037c4727 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -43,6 +43,8 @@ constexpr auto kCdist = "Cdist"; constexpr auto kCdistGrad = "CdistGrad"; // image constexpr auto kCropAndResizeGradBoxes = "CropAndResizeGradBoxes"; +constexpr auto kResizeBilinearV2 = "ResizeBilinearV2"; +constexpr auto kResizeBilinearGrad = "ResizeBilinearGrad"; // Arithmetic constexpr auto kScalarAdd = "ScalarAdd"; @@ -472,6 +474,8 @@ GVAR_DEF(PrimitivePtr, kPrimSegmentSum, std::make_shared(kSegmentSum) // image GVAR_DEF(PrimitivePtr, kPrimCropAndResizeGradBoxes, std::make_shared(kCropAndResizeGradBoxes)); +GVAR_DEF(PrimitivePtr, kPrimResizeBilinearV2, std::make_shared(kResizeBilinearV2)); +GVAR_DEF(PrimitivePtr, kPrimResizeBilinearGrad, std::make_shared(kResizeBilinearGrad)); // NN GVAR_DEF(PrimitivePtr, kPrimCeLU, std::make_shared("CeLU")); diff --git a/mindspore/core/ops/grad/resize_bilinear_grad.cc b/mindspore/core/ops/grad/resize_bilinear_grad.cc new file mode 100644 index 00000000000..16ae096911b --- /dev/null +++ b/mindspore/core/ops/grad/resize_bilinear_grad.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "ops/grad/resize_bilinear_grad.h" +#include +#include +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +void ResizeBilinearGrad::set_align_corners(const bool align_corners) { + (void)this->AddAttr(kAlignCorners, api::MakeValue(align_corners)); +} + +bool ResizeBilinearGrad::get_align_corners() const { + auto value_ptr = GetAttr(kAlignCorners); + return GetValue(value_ptr); +} + +void ResizeBilinearGrad::set_half_pixel_centers(const bool half_pixel_centers) { + (void)this->AddAttr(kHalfPixelCenters, api::MakeValue(half_pixel_centers)); +} + +bool ResizeBilinearGrad::get_half_pixel_centers() const { + auto value_ptr = GetAttr(kHalfPixelCenters); + return GetValue(value_ptr); +} + +namespace { +abstract::ShapePtr ResizeBilinearGradInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto x = input_args[kOriginalImageIndex]->BuildShape(); + MS_EXCEPTION_IF_NULL(x); + auto shape_x = x->cast(); + MS_EXCEPTION_IF_NULL(shape_x); + return shape_x; +} + +TypePtr ResizeBilinearGradInferType(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + const int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name); + MS_EXCEPTION_IF_NULL(input_args[kOriginalImageIndex]); + (void)CheckAndConvertUtils::CheckArgs(prim_name, input_args, kOriginalImageIndex); + auto x_type = input_args[kOriginalImageIndex]->BuildType(); + MS_EXCEPTION_IF_NULL(x_type); + if (!x_type->isa()) { + MS_EXCEPTION(TypeError) << "For '" << prim_name << "', input must be a Tensor, but got: " << x_type->ToString() + << "."; + } + return x_type; +} +} // namespace + +void ResizeBilinearGrad::Init(const bool align_corners, const bool half_pixel_centers) { + this->set_align_corners(align_corners); + this->set_half_pixel_centers(half_pixel_centers); +} + +MIND_API_OPERATOR_IMPL(ResizeBilinearGrad, BaseOperator); + +AbstractBasePtr ResizeBilinearGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + auto infer_type = ResizeBilinearGradInferType(primitive, input_args); + auto infer_shape = ResizeBilinearGradInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} + +REGISTER_PRIMITIVE_EVAL_IMPL(ResizeBilinearGrad, prim::kPrimResizeBilinearGrad, ResizeBilinearGradInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/grad/resize_bilinear_grad.h b/mindspore/core/ops/grad/resize_bilinear_grad.h new file mode 100644 index 00000000000..17c785d72de --- /dev/null +++ b/mindspore/core/ops/grad/resize_bilinear_grad.h @@ -0,0 +1,58 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_GRAD_RESIZE_BILINEAR_GRAD_H_ +#define MINDSPORE_CORE_OPS_GRAD_RESIZE_BILINEAR_GRAD_H_ +#include +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameResizeBilinearGrad = "ResizeBilinearGrad"; +constexpr auto kOriginalImageIndex = 1; +/// \brief Resizes an image to a certain size using the bilinear interpolation. +/// Refer to Python API @ref mindspore.ops.ResizeBilinearGrad for more details. +class MIND_API ResizeBilinearGrad : public BaseOperator { + public: + MIND_API_BASE_MEMBER(ResizeBilinearGrad); + /// \brief Constructor. + ResizeBilinearGrad() : BaseOperator(kNameResizeBilinearGrad) { InitIOName({"grads", "original_image"}, {"y"}); } + /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ResizeBilinearGrad for the inputs. + void Init(const bool align_corners = false, const bool half_pixel_centers = false); + /// \brief Set align_corners. + void set_align_corners(const bool align_corners); + /// \brief Set half_pixel_centers. + void set_half_pixel_centers(const bool half_pixel_centers); + /// \brief Get align_corners. + /// + /// \return align_corners. + bool get_align_corners() const; + /// \brief Get half_pixel_centers. + /// + /// \return half_pixel_centers. + bool get_half_pixel_centers() const; +}; +abstract::AbstractBasePtr ResizeBilinearGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_GRAD_RESIZE_BILINEAR_GRAD_H_ diff --git a/mindspore/core/ops/op_name.h b/mindspore/core/ops/op_name.h index bc64ebb5e9f..00a725744b8 100644 --- a/mindspore/core/ops/op_name.h +++ b/mindspore/core/ops/op_name.h @@ -287,6 +287,7 @@ constexpr auto kIndexing = "indexing"; constexpr auto kModulated = "modulated"; constexpr auto kAdjoint = "adjoint"; constexpr auto kInplaceAlgo = "inplace_algo"; +constexpr auto kHalfPixelCenters = "half_pixel_centers"; enum Index : size_t { kInputIndex0 = 0, diff --git a/mindspore/core/ops/resize_bilinear_v2.cc b/mindspore/core/ops/resize_bilinear_v2.cc new file mode 100644 index 00000000000..a9b9c9acf4e --- /dev/null +++ b/mindspore/core/ops/resize_bilinear_v2.cc @@ -0,0 +1,191 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "ops/resize_bilinear_v2.h" +#include +#include +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +void ResizeBilinearV2::set_align_corners(const bool align_corners) { + (void)this->AddAttr(kAlignCorners, api::MakeValue(align_corners)); +} + +bool ResizeBilinearV2::get_align_corners() const { + auto value_ptr = GetAttr(kAlignCorners); + return GetValue(value_ptr); +} + +void ResizeBilinearV2::set_half_pixel_centers(const bool half_pixel_centers) { + (void)this->AddAttr(kHalfPixelCenters, api::MakeValue(half_pixel_centers)); +} + +bool ResizeBilinearV2::get_half_pixel_centers() const { + auto value_ptr = GetAttr(kHalfPixelCenters); + return GetValue(value_ptr); +} + +namespace { +void GetSizeValue(const abstract::AbstractBasePtr &size, std::vector *size_value, + std::vector *min_size, std::vector *max_size) { + MS_EXCEPTION_IF_NULL(size); + auto size_v = size->BuildValue(); + MS_EXCEPTION_IF_NULL(size_v); + const int64_t size_size = 2; + auto prim_name = kNameResizeBilinearV2; + if (size->isa()) { + if (size_v->isa()) { + *size_value = CheckAndConvertUtils::CheckTensorIntValue("size", size_v, prim_name); + (void)CheckAndConvertUtils::CheckPositiveVector("size", *size_value, prim_name); + } else { + size_value->push_back(-1); + size_value->push_back(-1); + auto min_value = size->cast()->get_min_value(); + auto max_value = size->cast()->get_max_value(); + if (!min_value || !max_value) { + MS_LOG(INFO) << "inputs['size'] min or max value of " << prim_name << " is empty."; + return; + } + *min_size = GetValue>(min_value); + *max_size = GetValue>(max_value); + if (min_size->size() != size_size || max_size->size() != size_size) { + MS_EXCEPTION(ValueError) << "For " << prim_name + << ", inputs['size'] min and max value size must be 2, but got min: " + << min_size->size() << ", max: " << max_size->size() << "."; + } + } + } else if (size->isa() || size->isa()) { + *size_value = CheckAndConvertUtils::CheckIntOrTupleInt("size", size_v, prim_name); + (void)CheckAndConvertUtils::CheckPositiveVector("size", *size_value, prim_name); + } else { + MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " + << "size" + << " must be a Tensor or a tuple/list with all Int elements, but got " << size->ToString(); + } +} + +abstract::ShapePtr ResizeBilinearV2InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0); + auto x_shape = x_shape_ptr->shape(); + const int64_t shape_size = 4; + const int64_t size_size = 2; + (void)CheckAndConvertUtils::CheckInteger("the dimension of input_x", SizeToLong(x_shape.size()), kEqual, shape_size, + prim_name); + std::vector size_value; + std::vector min_size; + std::vector max_size; + auto input_num = input_args.size(); + constexpr auto kInputNum = 2; + constexpr auto kConstSizeInputNum = 1; + if (input_num == kInputNum) { + auto size = input_args[1]; + GetSizeValue(size, &size_value, &min_size, &max_size); + } else if (input_num == kConstSizeInputNum) { + // const size to attr by the convert_const_input_to_attr pass + size_value = GetValue>(primitive->GetAttr(kSize)); + } else { + MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the number of inputs must be " + << " 1 or 2, but got " << 
input_num; + } + (void)CheckAndConvertUtils::CheckInteger("the dimension of size", SizeToLong(size_value.size()), kEqual, size_size, + prim_name); + std::vector output_shape; + std::vector min_shape; + std::vector max_shape; + output_shape.push_back(x_shape[0]); + output_shape.push_back(x_shape[1]); + output_shape.push_back(size_value[0]); + output_shape.push_back(size_value[1]); + // static shape: + if (!x_shape_ptr->IsDynamic() && !(size_value[0] < 0)) { + return std::make_shared(output_shape); + } + // dynamic shape: + auto x_min_shape = x_shape_ptr->min_shape(); + auto x_max_shape = x_shape_ptr->max_shape(); + // The dynamic shape has no min_shape and max_shape + if ((x_shape_ptr->IsDynamic() && x_min_shape.empty()) || (size_value[0] < 0 && min_size.empty())) { + return std::make_shared(output_shape); + } + // Get min_shape and max_shape of output_shape + if (x_shape_ptr->IsDynamic()) { + (void)CheckAndConvertUtils::CheckInteger("the dimension of min_shape", SizeToLong(x_min_shape.size()), kEqual, + SizeToLong(x_shape.size()), prim_name); + (void)CheckAndConvertUtils::CheckInteger("the dimension of max_shape", SizeToLong(x_max_shape.size()), kEqual, + SizeToLong(x_shape.size()), prim_name); + min_shape.push_back(x_min_shape[0]); + min_shape.push_back(x_min_shape[1]); + max_shape.push_back(x_max_shape[0]); + max_shape.push_back(x_max_shape[1]); + } else { + min_shape.push_back(x_shape[0]); + min_shape.push_back(x_shape[1]); + max_shape.push_back(x_shape[0]); + max_shape.push_back(x_shape[1]); + } + if (size_value[0] < 0) { + (void)CheckAndConvertUtils::CheckInteger("the dimension of min_size", SizeToLong(min_size.size()), kEqual, + size_size, prim_name); + (void)CheckAndConvertUtils::CheckInteger("the dimension of max_size", SizeToLong(max_size.size()), kEqual, + size_size, prim_name); + min_shape.push_back(min_size[0]); + min_shape.push_back(min_size[1]); + max_shape.push_back(max_size[0]); + max_shape.push_back(max_size[1]); + } else { + min_shape.push_back(size_value[0]); + min_shape.push_back(size_value[1]); + max_shape.push_back(size_value[0]); + max_shape.push_back(size_value[1]); + } + return std::make_shared(output_shape, min_shape, max_shape); +} + +TypePtr ResizeBilinearV2InferType(const PrimitivePtr &primitive, + const std::vector &input_args) { + const std::set valid_types = {kFloat16, kFloat32}; + return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_types, primitive->name()); +} +} // namespace + +void ResizeBilinearV2::Init(const bool align_corners, const bool half_pixel_centers) { + this->set_align_corners(align_corners); + this->set_half_pixel_centers(half_pixel_centers); +} + +MIND_API_OPERATOR_IMPL(ResizeBilinearV2, BaseOperator); + +AbstractBasePtr ResizeBilinearV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + auto infer_type = ResizeBilinearV2InferType(primitive, input_args); + auto infer_shape = ResizeBilinearV2InferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); +} + +REGISTER_PRIMITIVE_EVAL_IMPL(ResizeBilinearV2, prim::kPrimResizeBilinearV2, ResizeBilinearV2Infer, nullptr, true); + +REGISTER_HOST_DEPENDS(kNameResizeBilinearV2, {1}); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/resize_bilinear_v2.h b/mindspore/core/ops/resize_bilinear_v2.h new file mode 100644 index 00000000000..052fb2dafcf --- /dev/null +++ b/mindspore/core/ops/resize_bilinear_v2.h @@ -0,0 +1,57 @@ +/** + * Copyright 2022 
Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_RESIZE_BILINEAR_V2_H_ +#define MINDSPORE_CORE_OPS_RESIZE_BILINEAR_V2_H_ +#include +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameResizeBilinearV2 = "ResizeBilinearV2"; +/// \brief Resizes an image to a certain size using the bilinear interpolation. +/// Refer to Python API @ref mindspore.ops.ResizeBilinearV2 for more details. +class MIND_API ResizeBilinearV2 : public BaseOperator { + public: + MIND_API_BASE_MEMBER(ResizeBilinearV2); + /// \brief Constructor. + ResizeBilinearV2() : BaseOperator(kNameResizeBilinearV2) { InitIOName({"x", "size"}, {"output"}); } + /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ResizeBilinearV2 for the inputs. + void Init(const bool align_corners = false, const bool half_pixel_centers = false); + /// \brief Set align_corners. + void set_align_corners(const bool align_corners); + /// \brief Set half_pixel_centers. + void set_half_pixel_centers(const bool half_pixel_centers); + /// \brief Get align_corners. + /// + /// \return align_corners. + bool get_align_corners() const; + /// \brief Get half_pixel_centers. + /// + /// \return half_pixel_centers. 
+ bool get_half_pixel_centers() const; +}; +abstract::AbstractBasePtr ResizeBilinearV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_RESIZE_BILINEAR_V2_H_ diff --git a/mindspore/core/ops/slice.cc b/mindspore/core/ops/slice.cc index 171047b0f18..0eb2213b335 100644 --- a/mindspore/core/ops/slice.cc +++ b/mindspore/core/ops/slice.cc @@ -151,7 +151,7 @@ std::vector Slice::get_size() const { return GetValue>(value_ptr); } -REGISTER_HOST_DEPENDS(kNameSlice, (std::set{1, 2})); +REGISTER_HOST_DEPENDS(kNameSlice, {1, 2}); REGISTER_PRIMITIVE_EVAL_IMPL(Slice, prim::kPrimSlice, SliceInfer, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/python/mindspore/ops/_grad/grad_inner_ops.py b/mindspore/python/mindspore/ops/_grad/grad_inner_ops.py index 2c71828fc0b..1d553a8143d 100644 --- a/mindspore/python/mindspore/ops/_grad/grad_inner_ops.py +++ b/mindspore/python/mindspore/ops/_grad/grad_inner_ops.py @@ -22,6 +22,7 @@ from .grad_base import bprop_getters from ..operations import _grad_ops as G from ..operations import _inner_ops as inner from ..composite.multitype_ops.zeros_like_impl import zeros_like +from ..operations import nn_ops as NN def _get_matrix_diag_assist(x_shape, x_dtype): @@ -154,3 +155,15 @@ def get_bprop_ps_roi_pooling(self): return dx, zeros_like(rois) return bprop + + +@bprop_getters.register(NN.ResizeBilinearV2) +def get_bprop_resize_bilinear(self): + """Grad definition for `ResizeBilinearV2` operation.""" + resize_grad = G.ResizeBilinearGrad(self.align_corners, self.half_pixel_centers) + + def bprop(x, size, out, dout): + dx = resize_grad(dout, x) + return dx, zeros_like(size) + + return bprop diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py index 1c30c3e6022..fe13bc2c590 100644 --- a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py @@ -375,6 +375,7 @@ from .ceil import _ceil_tbe from .ceil_ds import _ceil_ds_tbe from .log1p import _log1p_tbe from .resize_bilinear import _resize_bilinear_tbe +from .resize_bilinear_v2 import _resize_bilinear_v2_tbe from .resize_bilinear_grad import _resize_bilinear_grad_tbe from .flatten import _flatten_tbe from .roi_align import _roi_align_tbe diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py b/mindspore/python/mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py index 0951434ada0..49f3200a889 100644 --- a/mindspore/python/mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +++ b/mindspore/python/mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py @@ -22,8 +22,10 @@ resize_bilinear_grad_op_info = TBERegOp("ResizeBilinearGrad") \ .binfile_name("resize_bilinear_v2_grad.so") \ .compute_cost(10) \ .kernel_name("resize_bilinear_v2_grad") \ + .dynamic_compile_static(True) \ + .dynamic_shape(True) \ .partial_flag(True) \ - .need_check_supported(True) \ + .need_check_supported(False) \ .attr("align_corners", "optional", "bool", "all", "false") \ .attr("half_pixel_centers", "optional", "bool", "all", "false") \ .input(0, "grads", False, "required", "all") \ diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py b/mindspore/python/mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py new file mode 100644 index 00000000000..dc094f0ef41 --- /dev/null +++ 
b/mindspore/python/mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ResizeBilinear op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +resize_bilinear_v2_op_info = TBERegOp("ResizeBilinearV2") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("resize_bilinear_v2.so") \ + .compute_cost(10) \ + .kernel_name("resize_bilinear_v2") \ + .partial_flag(True) \ + .need_check_supported(False) \ + .dynamic_compile_static(True) \ + .dynamic_shape(True) \ + .attr("align_corners", "optional", "bool", "all", "false") \ + .attr("half_pixel_centers", "optional", "bool", "all", "false") \ + .input(0, "x", False, "required", "all") \ + .input(1, "size", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F32_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.I32_Default, DataType.F32_5HD) \ + .dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F16_5HD) \ + .get_op_info() + + +@op_info_register(resize_bilinear_v2_op_info) +def _resize_bilinear_v2_tbe(): + """ResizeBilinear TBE register""" + return diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py index 36d17ca6c58..c60b1db2636 100644 --- a/mindspore/python/mindspore/ops/function/__init__.py +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -223,6 +223,7 @@ from .nn_func import ( nll_loss, cross_entropy, grid_sample, + resize_bilinear, ) from .linalg_func import ( svd, diff --git a/mindspore/python/mindspore/ops/function/nn_func.py b/mindspore/python/mindspore/ops/function/nn_func.py index 21725d3e6f6..a7ad1468c04 100644 --- a/mindspore/python/mindspore/ops/function/nn_func.py +++ b/mindspore/python/mindspore/ops/function/nn_func.py @@ -868,6 +868,48 @@ def grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zero return NN.GridSampler3D(interpolation_mode, padding_mode, align_corners)(input_x, grid) +def resize_bilinear(x, size, align_corners=False, half_pixel_centers=False): + r""" + Resizes an image to a certain size using the bilinear interpolation. + + The resizing only affects the lower two dimensions which represent the height and width. + + Args: + x (Tensor): Image to be resized. Input images must be a 4-D tensor with shape + :math:`(batch, channels, height, width)`, with data type of float32 or float16. + size (Union[tuple[int], list[int]]): A tuple or list of 2 int elements :math:`(new\_height, new\_width)`, + the new size of the images. + align_corners (bool): If true, rescale input by :math:`(new\_height - 1) / (height - 1)`, + which exactly aligns the 4 corners of images and resized images. If false, + rescale by :math:`new\_height / height`. Default: False. 
+        half_pixel_centers (bool): Whether to use half-pixel center alignment. If set to True,
+            `align_corners` must be False. Default: False.
+
+    Returns:
+        Tensor, resized image. 4-D with shape :math:`(batch, channels, new\_height, new\_width)`,
+        with the same data type as input `x`.
+
+    Raises:
+        TypeError: If `align_corners` is not a bool.
+        TypeError: If `half_pixel_centers` is not a bool.
+        ValueError: If `align_corners` and `half_pixel_centers` are both True.
+        ValueError: If `half_pixel_centers` is True and device_target is CPU.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> x = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
+        >>> output = resize_bilinear(x, (5, 5))
+        >>> print(output)
+        [[[[1. 2. 3. 4. 5.]
+           [1. 2. 3. 4. 5.]
+           [1. 2. 3. 4. 5.]
+           [1. 2. 3. 4. 5.]
+           [1. 2. 3. 4. 5.]]]]
+    """
+    return NN.ResizeBilinearV2(align_corners, half_pixel_centers)(x, size)
+
+
 __all__ = [
     'adaptive_avgpool2d',
     'celu',
@@ -882,6 +924,7 @@ __all__ = [
     'pad',
     'cross_entropy',
     'grid_sample',
+    'resize_bilinear',
     'nll_loss'
 ]
 __all__.sort()
diff --git a/mindspore/python/mindspore/ops/operations/_grad_ops.py b/mindspore/python/mindspore/ops/operations/_grad_ops.py
index 22ae2240340..41415540897 100644
--- a/mindspore/python/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/python/mindspore/ops/operations/_grad_ops.py
@@ -1797,7 +1797,7 @@ class GatherDGrad(PrimitiveWithInfer):
         return grad_dtype
 
 
-class ResizeBilinearGrad(PrimitiveWithInfer):
+class ResizeBilinearGrad(Primitive):
     """Performs grad of ResizeBilinear operation."""
 
     @prim_attr_register
@@ -1808,18 +1808,12 @@ class ResizeBilinearGrad(PrimitiveWithInfer):
         self.align_corners = validator.check_value_type("align_corners", align_corners, [bool], self.name)
         self.half_pixel_centers = validator.check_value_type("half_pixel_centers",
                                                              half_pixel_centers, [bool], self.name)
+        self.init_prim_io_names(inputs=['grads', 'original_image'], outputs=['y'])
         if half_pixel_centers and align_corners:
             raise ValueError(f"If half_pixel_centers is True, align_corners must be False, but got {align_corners}")
         target = context.get_context("device_target")
-        if half_pixel_centers and target.lower() != "ascend":
-            raise ValueError(f"Currently `half_pixel_centers`=True only support in Ascend device_target, "
-                             f"but got {target}")
-
-    def infer_shape(self, dout_shape, orig_shape):
-        return orig_shape
-
-    def infer_dtype(self, dout_dtype, orig_type):
-        return orig_type
+        if half_pixel_centers and target.lower() == "cpu":
+            raise ValueError("Currently `half_pixel_centers`=True is not supported on the CPU device_target.")
 
 
 class ResizeNearestNeighborGrad(Primitive):
diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py
index d679dd17497..6ab8c2c0a67 100644
--- a/mindspore/python/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py
@@ -3501,6 +3501,40 @@ class ResizeBilinear(PrimitiveWithInfer):
         return input_dtype
 
 
+class ResizeBilinearV2(Primitive):
+    r"""
+    Resizes an image to a certain size using the bilinear interpolation.
+
+    Refer to :func:`mindspore.ops.resize_bilinear` for more details.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> x = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
+        >>> resize_bilinear_v2 = ops.ResizeBilinearV2()
+        >>> output = resize_bilinear_v2(x, (5, 5))
+        >>> print(output)
+        [[[[1. 2. 3. 4. 5.]
+           [1. 2. 3. 4. 5.]
+           [1. 2. 3. 4. 5.]
+           [1. 2. 3. 4. 5.]
+           [1. 2. 3. 4.
+    """
+
+    @prim_attr_register
+    def __init__(self, align_corners=False, half_pixel_centers=False):
+        """Initialize ResizeBilinearV2."""
+        self.init_prim_io_names(inputs=['x', 'size'], outputs=['y'])
+        self.align_corners = validator.check_value_type("align_corners", align_corners, [bool], self.name)
+        self.half_pixel_centers = validator.check_value_type("half_pixel_centers",
+                                                             half_pixel_centers, [bool], self.name)
+        if half_pixel_centers and align_corners:
+            raise ValueError(f"If half_pixel_centers is True, align_corners must be False, but got {align_corners}")
+        target = context.get_context("device_target")
+        if target.lower() == "cpu":
+            raise ValueError("Currently, ResizeBilinearV2 is not supported on the CPU device_target.")
+
+
 class OneHot(Primitive):
     r"""
     Computes a one-hot tensor.
diff --git a/tests/st/ops/ascend/test_resize_bilinear.py b/tests/st/ops/ascend/test_resize_bilinear.py
new file mode 100644
index 00000000000..0125c8272ac
--- /dev/null
+++ b/tests/st/ops/ascend/test_resize_bilinear.py
@@ -0,0 +1,103 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+from mindspore import context, Tensor
+import mindspore.ops as ops
+from mindspore import nn
+from mindspore import ParameterTuple
+
+
+class NetResizeBilinear(nn.Cell):
+    def construct(self, inputs, size):
+        return ops.resize_bilinear(inputs, size)
+
+
+def case():
+    datatype = np.float16
+    input_tensor = Tensor(np.array(
+        [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]).astype(datatype))
+    resize_nn = NetResizeBilinear()
+    output = resize_nn(input_tensor, (9, 9))
+    expected_output = np.array([[[[0.1, 0.1333, 0.1666, 0.2, 0.2333, 0.2666, 0.3, 0.3, 0.3],
+                                  [0.2, 0.2333, 0.2666, 0.2998, 0.3333, 0.3667, 0.4, 0.4, 0.4],
+                                  [0.2998, 0.3333, 0.3667, 0.4, 0.433, 0.4666, 0.5, 0.5, 0.5],
+                                  [0.4, 0.433, 0.4666, 0.5, 0.533, 0.5664, 0.6, 0.6, 0.6],
+                                  [0.5, 0.533, 0.5664, 0.5996, 0.6333, 0.6665, 0.6997, 0.6997, 0.6997],
+                                  [0.6, 0.6333, 0.6665, 0.6997, 0.733, 0.766, 0.8, 0.7993, 0.8],
+                                  [0.7, 0.7334, 0.7666, 0.8, 0.833, 0.866, 0.9, 0.8994, 0.8994],
+                                  [0.7, 0.7334, 0.7666, 0.8, 0.833, 0.866, 0.8994, 0.8994, 0.8994],
+                                  [0.7, 0.7334, 0.7666, 0.8, 0.8325, 0.866,
+                                   0.8994, 0.8994, 0.8994]]]]).astype(datatype)
+    assert np.allclose(output.asnumpy(), expected_output, 1e-3, 1e-3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_resize_bilinear_ascend():
+    """
+    Feature: Test mindspore.ops.resize_bilinear on Ascend.
+    Description: The size is an input.
+    Expectation: Assert that results are consistent with expect.
+ """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + case() + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + case() + + +class GradNetWrtX(nn.Cell): + def __init__(self, net): + super(GradNetWrtX, self).__init__() + self.net = net + self.grad_op = ops.GradOperation(get_all=True, get_by_list=True, sens_param=True) + self.params = ParameterTuple(net.trainable_params()) + + def construct(self, *inputs): + gradient_function = self.grad_op(self.net, self.params) + return gradient_function(*inputs) + + +def case_grad(): + x = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float32) + out_grad = np.array([[[[1, 1, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 1], + [0, 0, 1, 1]]]]).astype(np.float32) + expect = np.array([[[[2.25, 0.75], + [0.75, 4.25]]]]).astype(np.float32) + net = NetResizeBilinear() + grad_net = GradNetWrtX(net) + output = grad_net(Tensor(x), (4, 4), Tensor(out_grad)) + assert np.allclose(output[0][0].asnumpy(), expect, 1e-4, 1e-4) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_resize_bilinear_grad_ascend(): + """ + Feature: Test ResizeBilinearGrad on ascend. + Description: align_corners is False. + Expectation: Assert that results are consistent with expect. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + case_grad() + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + case_grad() diff --git a/tests/st/ops/dynamic_shape/test_resize_bilinear_dyn.py b/tests/st/ops/dynamic_shape/test_resize_bilinear_dyn.py new file mode 100644 index 00000000000..8606d677d09 --- /dev/null +++ b/tests/st/ops/dynamic_shape/test_resize_bilinear_dyn.py @@ -0,0 +1,133 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +import pytest + +from mindspore import context, Tensor +import mindspore.ops as ops +from mindspore import nn + + +def get_data(): + datatype = np.float32 + input_data = np.array( + [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]).astype(datatype) + size = [9, 9] + expected_output = np.array([[[[0.1, 0.1333, 0.1666, 0.2, 0.2333, 0.2666, 0.3, 0.3, 0.3], + [0.2, 0.2333, 0.2666, 0.2998, 0.3333, 0.3667, 0.4, 0.4, 0.4], + [0.2998, 0.3333, 0.3667, 0.4, 0.433, 0.4666, 0.5, 0.5, 0.5], + [0.4, 0.433, 0.4666, 0.5, 0.533, 0.5664, 0.6, 0.6, 0.6], + [0.5, 0.533, 0.5664, 0.5996, 0.6333, 0.6665, 0.6997, 0.6997, 0.6997], + [0.6, 0.6333, 0.6665, 0.6997, 0.733, 0.766, 0.8, 0.7993, 0.8], + [0.7, 0.7334, 0.7666, 0.8, 0.833, 0.866, 0.9, 0.8994, 0.8994], + [0.7, 0.7334, 0.7666, 0.8, 0.833, 0.866, 0.8994, 0.8994, 0.8994], + [0.7, 0.7334, 0.7666, 0.8, 0.8325, 0.866, + 0.8994, 0.8994, 0.8994]]]]).astype(datatype) + return input_data, size, expected_output + + +class NetResizeBilinear(nn.Cell): + def construct(self, inputs, size, indices_input, axis): + unique_input_index, _ = ops.unique(indices_input) + inputs_dyn = ops.gather(inputs, unique_input_index, axis) + return ops.resize_bilinear(inputs_dyn, size) + + +def case_input_dyn(mode, device_target): + context.set_context(mode=mode, device_target=device_target) + input_data, size, expected = get_data() + + resize_nn = NetResizeBilinear() + axis_input = 3 + indices_input = np.array([i for i in range(input_data.shape[axis_input])]) + output = resize_nn(Tensor(input_data), size, Tensor(indices_input), axis_input) + assert np.allclose(output.asnumpy(), expected, 1e-3, 1e-3) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_resize_bilinear_ascend(): + """ + Feature: Test resize_bilinear on ascend. + Description: The shape of input is dynamic. + Expectation: Assert that results are consistent with expect. + """ + case_input_dyn(context.GRAPH_MODE, "Ascend") + case_input_dyn(context.PYNATIVE_MODE, "Ascend") + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_resize_bilinear_gpu(): + """ + Feature: Test resize_bilinear on GPU. + Description: The shape of input is dynamic. + Expectation: Assert that results are consistent with expect. 
+ """ + case_input_dyn(context.GRAPH_MODE, "GPU") + case_input_dyn(context.PYNATIVE_MODE, "GPU") + + +class NetResizeBilinearSizeDyn(nn.Cell): + def construct(self, x, y, indices_x, indices_y, axis_x, axis_y): + unique_x_index, _ = ops.unique(indices_x) + x_dyn = ops.gather(x, unique_x_index, axis_x) + unique_y_index, _ = ops.unique(indices_y) + y_dyn = ops.gather(y, unique_y_index, axis_y) + size_dyn = ops.TensorShape()(y_dyn) + return ops.resize_bilinear(x_dyn, size_dyn) + + +def case_input_size_dyn(mode, device_target): + context.set_context(mode=mode, device_target=device_target) + x_data, size, expected = get_data() + y = np.random.rand(*size).astype(np.float32) + resize_nn = NetResizeBilinearSizeDyn() + axis_x = 3 + indices_x = np.array([i for i in range(x_data.shape[axis_x])], dtype=np.int32) + axis_y = 1 + indices_y = np.array([i for i in range(y.shape[axis_y])], dtype=np.int32) + output = resize_nn(Tensor(x_data), Tensor(y), Tensor(indices_x), Tensor(indices_y), axis_x, axis_y) + assert np.allclose(output.asnumpy(), expected, 1e-3, 1e-3) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_resize_bilinear_size_dyn_ascend(): + """ + Feature: Test resize_bilinear on Ascend. + Description: The shape of input and size is dynamic. + Expectation: Assert that results are consistent with expect. + """ + case_input_size_dyn(context.GRAPH_MODE, "Ascend") + case_input_size_dyn(context.PYNATIVE_MODE, "Ascend") + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_resize_bilinear_size_dyn_gpu(): + """ + Feature: Test resize_bilinear on GPU. + Description: The shape of input and size is dynamic. + Expectation: Assert that results are consistent with expect. + """ + case_input_size_dyn(context.GRAPH_MODE, "GPU") + case_input_size_dyn(context.PYNATIVE_MODE, "GPU") diff --git a/tests/st/ops/dynamic_shape/test_resize_bilinear_grad_dyn.py b/tests/st/ops/dynamic_shape/test_resize_bilinear_grad_dyn.py new file mode 100644 index 00000000000..4c79650afbc --- /dev/null +++ b/tests/st/ops/dynamic_shape/test_resize_bilinear_grad_dyn.py @@ -0,0 +1,83 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + + +class ResizeBilinearGradNet(nn.Cell): + def __init__(self, align_corners=False): + super(ResizeBilinearGradNet, self).__init__() + self.rb1 = G.ResizeBilinearGrad(align_corners=align_corners) + + def construct(self, dy, size, indices_dy, indices_size, axis): + unique_dy_index, _ = ops.unique(indices_dy) + unique_size_index, _ = ops.unique(indices_size) + dy_ = ops.gather(dy, unique_dy_index, axis) + size_ = ops.gather(size, unique_size_index, axis) + return self.rb1(dy_, size_) + + +def dyn_case(): + dy = np.array([[[[1, 1, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 1], + [0, 0, 1, 1]]]]).astype(np.float32) + x = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float32) + expect = np.array([[[[2.25, 0.75], + [0.75, 4.25]]]]).astype(np.float32) + net = ResizeBilinearGradNet() + axis = 3 + indices_dy = np.array([i for i in range(dy.shape[axis])]) + indices_x = np.array([i for i in range(x.shape[axis])]) + output = net(Tensor(dy), Tensor(x), Tensor(indices_dy), Tensor(indices_x), axis) + assert np.all(output.asnumpy() == expect) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_resize_bilinear_grad_dyn_ascend(): + """ + Feature: Test ResizeBilinearGrad on Ascend. + Description: The shape of inputs is dynamic. + Expectation: Assert that results are consistent with expect. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + dyn_case() + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + dyn_case() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_resize_bilinear_grad_dyn_gpu(): + """ + Feature: Test ResizeBilinearGrad on GPU. + Description: The shape of inputs is dynamic. + Expectation: Assert that results are consistent with expect. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + dyn_case() + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + dyn_case() diff --git a/tests/st/ops/gpu/test_resize_bilinear_grad_op.py b/tests/st/ops/gpu/test_resize_bilinear_grad_op.py index 41b0933010f..74d929c6acf 100644 --- a/tests/st/ops/gpu/test_resize_bilinear_grad_op.py +++ b/tests/st/ops/gpu/test_resize_bilinear_grad_op.py @@ -23,9 +23,9 @@ from mindspore.ops.operations import _grad_ops as G class ResizeBilinearGradNet(nn.Cell): - def __init__(self, align_corners=False): + def __init__(self, align_corners=False, half_pixel_centers=False): super(ResizeBilinearGradNet, self).__init__() - self.rb1 = G.ResizeBilinearGrad(align_corners=align_corners) + self.rb1 = G.ResizeBilinearGrad(align_corners=align_corners, half_pixel_centers=half_pixel_centers) def construct(self, dy, size): return self.rb1(dy, size) @@ -91,3 +91,41 @@ def test_resize_bilinear_grad(): net = ResizeBilinearGradNet() output = net(Tensor(dy), Tensor(x)) assert np.all(output.asnumpy() == expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_resize_bilinear_grad_half_pixel_centers(): + """ + Feature: Test ResizeBilinearGrad on GPU. + Description: The half_pixel_centers is True. + Expectation: Assert that results are consistent with expect. 
+ """ + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float16) + + x = np.array([[[[1.1, 2.2, 3.2, 2.5], + [3.3, 4.4, 5.7, 8.1], + [3.3, 4.4, 5.7, 8.1], + [3.3, 4.4, 5.7, 8.1]]]]).astype(np.float16) + expect = np.array([[[[0.25, 0.25, 0.5, 0.5], + [0.25, 0.25, 0.5, 0.5], + [0.75, 0.75, 1.0, 1.0], + [0.75, 0.75, 1.0, 1.0]]]], dtype=np.float16) + net = ResizeBilinearGradNet(half_pixel_centers=True) + output = net(Tensor(dy), Tensor(x)) + assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32) + + x = np.array([[[[1.1, 2.2, 3.2, 2.5], + [3.3, 4.4, 5.7, 8.1], + [3.3, 4.4, 5.7, 8.1], + [3.3, 4.4, 5.7, 8.1]]]]).astype(np.float32) + expect = np.array([[[[0.25, 0.25, 0.5, 0.5], + [0.25, 0.25, 0.5, 0.5], + [0.75, 0.75, 1.0, 1.0], + [0.75, 0.75, 1.0, 1.0]]]], dtype=np.float32) + net = ResizeBilinearGradNet(half_pixel_centers=True) + output = net(Tensor(dy), Tensor(x)) + assert np.all(output.asnumpy() == expect) diff --git a/tests/st/ops/gpu/test_resize_bilinear_op.py b/tests/st/ops/gpu/test_resize_bilinear_op.py index 08067356982..1f687c86f88 100644 --- a/tests/st/ops/gpu/test_resize_bilinear_op.py +++ b/tests/st/ops/gpu/test_resize_bilinear_op.py @@ -16,6 +16,7 @@ import numpy as np import pytest from mindspore import context, Tensor +import mindspore.ops as ops from mindspore.ops import operations as P from mindspore import nn @@ -570,3 +571,30 @@ def test_resize_nn_grayscale_align_corners_float(datatype=np.float32): diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) assert np.all(abs(diff_align) < error) + + +class NetResizeBilinearFunc(nn.Cell): + def construct(self, inputs, size, align_corner=False, half_pixel_centers=False): + return ops.resize_bilinear(inputs, size, align_corner, half_pixel_centers) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_resize_nn_func_half_pixel_centers(datatype=np.float32): + """ + Feature: Test resize_bilinear on GPU. + Description: The half_pixel_centers is True. + Expectation: Assert that results are consistent with expect. + """ + input_tensor = Tensor( + np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(datatype)) + resize_nn_func = NetResizeBilinearFunc() + output = resize_nn_func(input_tensor, (3, 7), align_corner=False, half_pixel_centers=True) + expected_output = np.array([[[[0.1, 0.13571429, 0.19285715, 0.25, 0.30714288, + 0.36428574, 0.4], + [0.3, 0.3357143, 0.39285716, 0.45, 0.5071429, + 0.56428576, 0.6], + [0.5, 0.5357143, 0.5928572, 0.65, 0.7071429, + 0.76428574, 0.8]]]], dtype=datatype) + assert np.allclose(output.asnumpy(), expected_output)
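The `half_pixel_centers` expectations hard-coded in the GPU tests above follow the usual half-pixel coordinate mapping: output index `i` samples the input at `(i + 0.5) * in_size / out_size - 0.5`, clamped to the valid range, followed by linear interpolation between the two neighbouring samples. The sketch below is a minimal NumPy reference of that mapping under those assumptions; the helper names `np_resize_bilinear_half_pixel` and `_weights` are illustrative only and are not part of the MindSpore API. Running it reproduces the `expected_output` used in `test_resize_nn_func_half_pixel`.

# Minimal NumPy sketch of bilinear resizing with half_pixel_centers=True.
# Helper names are illustrative only and are not part of MindSpore.
import numpy as np


def _weights(out_size, in_size):
    """Per output coordinate: low/high source indices and the lerp weight toward `high`."""
    scale = in_size / out_size
    src = (np.arange(out_size) + 0.5) * scale - 0.5   # half-pixel mapping
    low = np.maximum(np.floor(src), 0).astype(np.int64)
    high = np.minimum(np.ceil(src), in_size - 1).astype(np.int64)
    lerp = src - np.floor(src)
    return low, high, lerp


def np_resize_bilinear_half_pixel(x, size):
    """Resize an NCHW array `x` to `size` = (new_h, new_w)."""
    new_h, new_w = size
    h_low, h_high, h_lerp = _weights(new_h, x.shape[2])
    w_low, w_high, w_lerp = _weights(new_w, x.shape[3])
    # Interpolate along height first, then along width.
    rows = (x[:, :, h_low, :] * (1.0 - h_lerp)[None, None, :, None]
            + x[:, :, h_high, :] * h_lerp[None, None, :, None])
    return rows[:, :, :, w_low] * (1.0 - w_lerp) + rows[:, :, :, w_high] * w_lerp


if __name__ == "__main__":
    x = np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]], dtype=np.float32)
    # Matches the expected_output of test_resize_nn_func_half_pixel, e.g. the first
    # output row starts 0.1, 0.1357, 0.1929, 0.25, ...
    print(np.round(np_resize_bilinear_half_pixel(x, (3, 7)), 4))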