diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_bilinear_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_bilinear_impl.cu index 898125f5ec2..02d23d25a2d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_bilinear_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_bilinear_impl.cu @@ -49,7 +49,7 @@ __global__ void ResizeBilinear(const T *input, const int n, const int c, const i } // fp16 path -__global__ void ResizeBilinearGrad(const float *input, const int n, const int c, const int input_h, const int input_w, +__global__ void ResizeBilinearGrad(const half *input, const int n, const int c, const int input_h, const int input_w, const int output_h, const int output_w, const int nchw, const int chw, const int hw, const float h_scale, const float w_scale, half *output, float *interim) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) { @@ -67,11 +67,11 @@ __global__ void ResizeBilinearGrad(const float *input, const int n, const int c, const float w_beta = 1.0f - w_alpha; const float h_alpha = posh_scaled - h_low; const float h_beta = 1.0f - h_alpha; - const float grad = input[pos]; - const float dp1 = h_beta * w_beta * grad; - const float dp2 = h_beta * w_alpha * grad; - const float dp3 = h_alpha * w_beta * grad; - const float dp4 = h_alpha * w_alpha * grad; + const half grad = input[pos]; + const half dp1 = static_cast(h_beta * w_beta) * grad; + const half dp2 = static_cast(h_beta * w_alpha) * grad; + const half dp3 = static_cast(h_alpha * w_beta) * grad; + const half dp4 = static_cast(h_alpha * w_alpha) * grad; const int output_start = output_h * output_w * (posn * c + posc); atomicAdd(&interim[output_start + (h_low * output_w) + w_low], dp1); atomicAdd(&interim[output_start + (h_low * output_w) + w_high], dp2); @@ -133,7 +133,7 @@ void CalResizeBilinear(const T *input, const int n, const int c, const int input return; } -void CalResizeBilinearGrad(const float *input, const int n, const int c, const int input_h, const int input_w, +void CalResizeBilinearGrad(const half *input, const int n, const int c, const int input_h, const int input_w, const int output_h, const int output_w, const float h_scale, const float w_scale, half *output, float *interim, cudaStream_t cuda_stream) { const int hw = input_h * input_w; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_bilinear_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_bilinear_impl.cuh index e50641d3824..e175c4f9cdf 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_bilinear_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_bilinear_impl.cuh @@ -21,7 +21,7 @@ template void CalResizeBilinear(const T *input, const int n_, const int c_, const int input_h_, const int input_w_, const int output_h_, const int output_w_, const float h_scale, const float w_scale, T *output, cudaStream_t cuda_stream); -void CalResizeBilinearGrad(const float *input, const int n_, const int c_, const int input_h_, const int input_w_, +void CalResizeBilinearGrad(const half *input, const int n_, const int c_, const int input_h_, const int input_w_, const int output_h_, const int output_w_, const float h_scale, const float w_scale, half *output, float *interim, cudaStream_t cuda_stream); void CalResizeBilinearGrad(const float *input, const int n_, const int c_, const int input_h_, const int input_w_, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/resize_bilinear_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/resize_bilinear_grad_gpu_kernel.cc index e75e22bd468..de4349de080 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/resize_bilinear_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/resize_bilinear_grad_gpu_kernel.cc @@ -25,7 +25,7 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( ResizeBilinearGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), ResizeBilinearGradGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/resize_bilinear_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/resize_bilinear_grad_gpu_kernel.h index a846d5419b1..f95597a3278 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/resize_bilinear_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/resize_bilinear_grad_gpu_kernel.h @@ -36,7 +36,7 @@ class ResizeBilinearGradGpuKernel : public GpuKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) override { - float *dy = GetDeviceAddress(inputs, 0); + T *dy = GetDeviceAddress(inputs, 0); float *interim = GetDeviceAddress(workspace, 0); T *dx = GetDeviceAddress(outputs, 0); float h_scale = Scaling(dx_h_, dy_h_, align_corners_); @@ -81,7 +81,7 @@ class ResizeBilinearGradGpuKernel : public GpuKernel { dy_w_ = SizeToInt(dy_shape[3]); dx_h_ = SizeToInt(dx_shape[2]); dx_w_ = SizeToInt(dx_shape[3]); - dy_size_ = sizeof(float); + dy_size_ = sizeof(T); for (auto x : dy_shape) { dy_size_ *= x; } @@ -89,7 +89,7 @@ class ResizeBilinearGradGpuKernel : public GpuKernel { for (auto x : dx_shape) { dx_size_ *= x; } - workspace_size_ = (dx_size_ / sizeof(T)) * sizeof(float); + workspace_size_ = (dx_size_ / sizeof(T)) * sizeof(T); align_corners_ = GetAttr(kernel_node, "align_corners"); InitSizeLists(); return true; diff --git a/tests/st/ops/cpu/test_resize_bilinear_op.py b/tests/st/ops/cpu/test_resize_bilinear_op.py index dab90236b57..4c22152568c 100644 --- a/tests/st/ops/cpu/test_resize_bilinear_op.py +++ b/tests/st/ops/cpu/test_resize_bilinear_op.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,24 +37,16 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16): # larger h and w resize_nn = NetResizeBilinear((9, 9)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.13330078, 0.16662598, 0.19995117, 0.23331706, - 0.26668295, 0.30004883, 0.30004883, 0.30004883], - [0.19995117, 0.23328993, 0.26662868, 0.29996747, 0.33333334, - 0.36669925, 0.40006512, 0.40006512, 0.40006512], - [0.29992676, 0.33327907, 0.36663142, 0.39998373, 0.4333496, - 0.4667155, 0.5000814, 0.5000814, 0.5000814], - [0.39990234, 0.43326822, 0.46663412, 0.5, 0.5333659, - 0.5667318, 0.60009766, 0.60009766, 0.60009766], - [0.5, 0.5333116, 0.5666233, 0.59993494, 0.6333008, - 0.66666675, 0.7000326, 0.7000326, 0.7000326], - [0.60009766, 0.633355, 0.66661245, 0.6998698, 0.7332357, - 0.7666016, 0.79996747, 0.79996747, 0.79996747], - [0.7001953, 0.73339844, 0.76660156, 0.7998047, 0.8331706, - 0.8665365, 0.89990234, 0.89990234, 0.89990234], - [0.7001953, 0.73339844, 0.76660156, 0.7998047, 0.8331706, - 0.8665365, 0.89990234, 0.89990234, 0.89990234], - [0.7001953, 0.73339844, 0.76660156, 0.7998047, 0.8331706, - 0.8665365, 0.89990234, 0.89990234, 0.89990234]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.1333, 0.1666, 0.2, 0.2333, 0.2666, 0.3, 0.3, 0.3], + [0.2, 0.2333, 0.2666, 0.2998, 0.3333, 0.3667, 0.4, 0.4, 0.4], + [0.2998, 0.3333, 0.3665, 0.4, 0.433, 0.4668, 0.5, 0.5, 0.5], + [0.4, 0.4333, 0.4666, 0.5, 0.533, 0.567, 0.6, 0.6, 0.6], + [0.5, 0.533, 0.5664, 0.6, 0.6333, 0.667, 0.7, 0.7, 0.7], + [0.6, 0.6333, 0.6665, 0.6997, 0.733, 0.7666, 0.8, 0.8, 0.8], + [0.7, 0.7334, 0.7666, 0.8, 0.833, 0.8667, 0.9, 0.9, 0.9], + [0.7, 0.7334, 0.7666, 0.8, 0.833, 0.8667, 0.9, 0.9, 0.9], + [0.7, 0.7334, 0.7666, 0.8, 0.833, 0.8667, 0.9, 0.9, 0.9]]]] + ).astype(np.float16)) error = np.ones(shape=[9, 9]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -62,7 +54,7 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16): # smaller h and w resize_nn = NetResizeBilinear((1, 1)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1]]]]).astype(np.float16)) error = np.ones(shape=[1, 1]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -71,7 +63,7 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16): resize_nn = NetResizeBilinear((1, 6)) output = resize_nn(input_tensor) expected_output = Tensor( - np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883, 0.30004883]]]]).astype(np.float32)) + np.array([[[[0.1, 0.1499, 0.2, 0.25, 0.3, 0.3]]]]).astype(np.float16)) error = np.ones(shape=[1, 6]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -80,8 +72,12 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16): resize_nn = NetResizeBilinear((6, 1)) output = resize_nn(input_tensor) expected_output = Tensor( - np.array([[[[0.09997559], [0.24993896], [0.39990234], [0.5500488], [0.7001953], [0.7001953]]]]).astype( - np.float32)) + np.array([[[[0.1], + [0.2499], + [0.4], + [0.55], + [0.7], + [0.7]]]]).astype(np.float16)) error = np.ones(shape=[6, 1]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -90,7 +86,7 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16): resize_nn = NetResizeBilinear((1, 3)) output = resize_nn(input_tensor) expected_output = Tensor( - np.array([[[[0.09997559, 0.19995117, 0.30004883]]]]).astype(np.float32)) + np.array([[[[0.1, 0.2, 0.3]]]]).astype(np.float16)) error = np.ones(shape=[1, 3]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -98,12 +94,12 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16): # larger h, same w resize_nn = NetResizeBilinear((6, 3)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883], - [0.24993896, 0.3499756, 0.45007324], - [0.39990234, 0.5, 0.60009766], - [0.5500488, 0.64990234, 0.75], - [0.7001953, 0.7998047, 0.89990234], - [0.7001953, 0.7998047, 0.89990234]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3], + [0.2499, 0.35, 0.4502], + [0.4, 0.5, 0.6], + [0.55, 0.65, 0.75], + [0.7, 0.8, 0.9], + [0.7, 0.8, 0.9]]]]).astype(np.float16)) error = np.ones(shape=[6, 3]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -112,7 +108,9 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16): resize_nn = NetResizeBilinear((3, 1)) output = resize_nn(input_tensor) expected_output = Tensor( - np.array([[[[0.09997559], [0.39990234], [0.7001953]]]]).astype(np.float32)) + np.array([[[[0.1], + [0.4], + [0.7]]]]).astype(np.float16)) error = np.ones(shape=[3, 1]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -120,12 +118,9 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16): # same h, larger w resize_nn = NetResizeBilinear((3, 6)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883, - 0.30004883], - [0.39990234, 0.44995117, 0.5, 0.5500488, 0.60009766, - 0.60009766], - [0.7001953, 0.75, 0.7998047, 0.8498535, 0.89990234, - 0.89990234]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.1499, 0.2, 0.25, 0.3, 0.3], + [0.4, 0.45, 0.5, 0.55, 0.6, 0.6], + [0.7, 0.75, 0.8, 0.8496, 0.9, 0.9]]]]).astype(np.float16)) error = np.ones(shape=[3, 6]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -134,9 +129,9 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16): resize_nn = NetResizeBilinear((3, 3)) output = resize_nn(input_tensor) expected_output = Tensor(np.array( - [[[[0.09997559, 0.19995117, 0.30004883], - [0.39990234, 0.5, 0.60009766], - [0.7001953, 0.7998047, 0.89990234]]]]).astype(np.float32)) + [[[[0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9]]]]).astype(np.float16)) error = np.ones(shape=[3, 3]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -251,20 +246,13 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16): # larger h and w resize_nn = NetResizeBilinear((7, 7)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.15710449, 0.21425085, 0.2714495, 0.3285784, - 0.38563755, 0.39990234], - [0.27141464, 0.3285734, 0.3857422, 0.44294086, 0.5000399, - 0.55703926, 0.57128906], - [0.44285366, 0.5000423, 0.5572336, 0.6144322, 0.67150134, - 0.7284409, 0.7426758], - [0.6142578, 0.50819117, 0.44293588, 0.5001146, 0.5571937, - 0.6141731, 0.62841797], - [0.78564453, 0.4346799, 0.18574369, 0.2428925, 0.3000015, - 0.3570706, 0.3713379], - [0.89990234, 0.3856724, 0.01428223, 0.07141115, 0.12854005, - 0.18566895, 0.19995117], - [0.89990234, 0.3856724, 0.01428223, 0.07141115, 0.12854005, - 0.18566895, 0.19995117]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.1571, 0.2142, 0.2715, 0.3286, 0.3857, 0.4], + [0.2715, 0.3286, 0.3857, 0.4429, 0.5, 0.557, 0.5713], + [0.4429, 0.5, 0.557, 0.6143, 0.6714, 0.7285, 0.7427], + [0.6143, 0.5083, 0.4429, 0.5005, 0.557, 0.6143, 0.6284], + [0.7856, 0.4346, 0.1855, 0.2429, 0.2998, 0.357, 0.3716], + [0.9, 0.3857, 0.01428, 0.0714, 0.1285, 0.1857, 0.2], + [0.9, 0.3857, 0.01428, 0.0714, 0.1285, 0.1857, 0.2]]]]).astype(np.float16)) error = np.ones(shape=[7, 7]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -273,8 +261,8 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16): resize_nn = NetResizeBilinear((2, 3)) output = resize_nn(input_tensor) expected_output = Tensor( - np.array([[[[0.09997559, 0.23331706, 0.36661786], - [0.6999512, 0.33339438, 0.46661377]]]]).astype(np.float32)) + np.array([[[[0.1, 0.2333, 0.3667], + [0.7, 0.3333, 0.4666]]]]).astype(np.float16)) error = np.ones(shape=[2, 3]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -282,10 +270,8 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16): # smaller h, larger w resize_nn = NetResizeBilinear((2, 7)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.15710449, 0.21425085, 0.2714495, 0.3285784, - 0.38563755, 0.39990234], - [0.6999512, 0.47143552, 0.3143398, 0.37150356, 0.4285976, - 0.48562187, 0.49987793]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.1571, 0.2142, 0.2715, 0.3286, 0.3857, 0.4], + [0.7, 0.4714, 0.3142, 0.3716, 0.4285, 0.4856, 0.5]]]]).astype(np.float16)) error = np.ones(shape=[2, 7]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -293,11 +279,11 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16): # larger h, smaller w resize_nn = NetResizeBilinear((5, 3)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.23331706, 0.36661786], - [0.33999026, 0.47340494, 0.6066081], - [0.5799805, 0.51343584, 0.64660645], - [0.8199219, 0.15335283, 0.28662106], - [0.89990234, 0.0333252, 0.16662598]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.2333, 0.3667], + [0.3398, 0.4731, 0.6064], + [0.58, 0.513, 0.6465], + [0.82, 0.1533, 0.2866], + [0.9, 0.03333, 0.1666]]]]).astype(np.float16)) error = np.ones(shape=[5, 3]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -305,8 +291,8 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16): # smaller h, same w resize_nn = NetResizeBilinear((2, 4)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883, 0.39990234], - [0.6999512, 0.30004883, 0.40008545, 0.49987793]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], + [0.7, 0.3, 0.4001, 0.5]]]]).astype(np.float16)) error = np.ones(shape=[2, 4]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -314,14 +300,14 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16): # larger h, same w resize_nn = NetResizeBilinear((8, 4)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883, 0.39990234], - [0.24998474, 0.3500061, 0.45010376, 0.5498657], - [0.3999939, 0.50006104, 0.6001587, 0.6998291], - [0.5499878, 0.52508545, 0.62516785, 0.724823], - [0.6999512, 0.30004883, 0.40008545, 0.49987793], - [0.84991455, 0.07501221, 0.17500305, 0.27493286], - [0.89990234, 0., 0.09997559, 0.19995117], - [0.89990234, 0., 0.09997559, 0.19995117]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], + [0.2499, 0.35, 0.4502, 0.55], + [0.4, 0.5, 0.6, 0.6997], + [0.55, 0.525, 0.625, 0.7246], + [0.7, 0.3, 0.4001, 0.5], + [0.8496, 0.0752, 0.1753, 0.2754], + [0.9, 0., 0.1, 0.2], + [0.9, 0., 0.1, 0.2]]]]).astype(np.float16)) error = np.ones(shape=[8, 4]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -329,9 +315,9 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16): # same h, smaller w resize_nn = NetResizeBilinear((3, 2)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.30004883], - [0.5, 0.7001953], - [0.89990234, 0.09997559]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.3], + [0.5, 0.7], + [0.9, 0.1]]]]).astype(np.float16)) error = np.ones(shape=[3, 2]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -339,12 +325,9 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16): # same h, larger w resize_nn = NetResizeBilinear((3, 6)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.16662598, 0.23331706, 0.30004883, 0.36661786, - 0.39990234], - [0.5, 0.56673175, 0.63346356, 0.7001953, 0.76660156, - 0.7998047], - [0.89990234, 0.2999674, 0.0333252, 0.09997559, 0.16662598, - 0.19995117]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.1666, 0.2333, 0.3, 0.3667, 0.4], + [0.5, 0.567, 0.6333, 0.7, 0.7666, 0.8], + [0.9, 0.3003, 0.03333, 0.1, 0.1666, 0.2]]]]).astype(np.float16)) error = np.ones(shape=[3, 6]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -352,9 +335,9 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16): # same w, same h (identity) resize_nn = NetResizeBilinear((3, 4)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883, 0.39990234], - [0.5, 0.60009766, 0.7001953, 0.7998047], - [0.89990234, 0., 0.09997559, 0.19995117]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 0., 0.1, 0.2]]]]).astype(np.float16)) error = np.ones(shape=[3, 4]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() assert np.all(abs(diff) < error) @@ -476,13 +459,12 @@ def test_resize_nn_grayscale_multiple_images_half(datatype=np.float16): resize_nn = NetResizeBilinear((2, 6)) output = resize_nn(input_tensor) - expected_output = Tensor(np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883, 0.30004883], - [0.5500488, 0.5999756, 0.64990234, 0.6999512, 0.75, 0.75]]], - [[[0.39990234, 0.44995117, 0.5, 0.5500488, 0.60009766, 0.60009766], - [0.40008545, 0.4499817, 0.49987793, 0.54992676, 0.5999756, 0.5999756]]], - [[[0.7001953, 0.75, 0.7998047, 0.8498535, 0.89990234, 0.89990234], - [0.24993896, 0.29995728, 0.3499756, 0.4000244, 0.45007324, - 0.45007324]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.1, 0.1499, 0.2, 0.25, 0.3, 0.3], + [0.55, 0.6, 0.65, 0.6997, 0.75, 0.75]]], + [[[0.4, 0.45, 0.5, 0.55, 0.6, 0.6], + [0.4001, 0.45, 0.5, 0.55, 0.6, 0.6]]], + [[[0.7, 0.75, 0.8, 0.8496, 0.9, 0.9], + [0.2499, 0.2998, 0.35, 0.4, 0.4502, 0.4502]]]]).astype(np.float16)) error = np.ones(shape=[3, 3, 2, 6]) * 1.0e-6 diff = output.asnumpy() - expected_output.asnumpy() @@ -520,18 +502,12 @@ def test_resize_nn_grayscale_align_corners_half(datatype=np.float16): resize_nn = NetResizeBilinear((3, 7)) output = resize_nn(input_tensor) - expected_output_align = Tensor(np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883, - 0.3499756, 0.39990234], - [0.2999878, 0.3500061, 0.4000244, 0.45007324, 0.5001221, - 0.5499878, 0.5998535], - [0.5, 0.5500488, 0.60009766, 0.6501465, 0.7001953, - 0.75, 0.7998047]]]]).astype(np.float32)) - expected_output = Tensor(np.array([[[[0.09997559, 0.15710449, 0.21425085, 0.2714495, 0.3285784, - 0.38563755, 0.39990234], - [0.36665854, 0.42383394, 0.4810152, 0.53821385, 0.59529626, - 0.6522624, 0.6665039], - [0.5, 0.55719864, 0.61439735, 0.671596, 0.72865516, - 0.7855748, 0.7998047]]]]).astype(np.float32)) + expected_output_align = Tensor(np.array([[[[0.1, 0.1499, 0.2, 0.25, 0.3, 0.35, 0.4], + [0.2998, 0.3499, 0.4, 0.4502, 0.5, 0.55, 0.5996], + [0.5, 0.55, 0.6, 0.6504, 0.7, 0.75, 0.8]]]]).astype(np.float16)) + expected_output = Tensor(np.array([[[[0.1, 0.1571, 0.2142, 0.2715, 0.3286, 0.3857, 0.4], + [0.3667, 0.4238, 0.481, 0.538, 0.595, 0.6523, 0.6665], + [0.5, 0.557, 0.6143, 0.672, 0.7285, 0.7856, 0.8]]]]).astype(np.float16)) error = np.ones(shape=[3, 7]) * 1.0e-6 diff_align = output_corners_aligned.asnumpy() - expected_output_align.asnumpy() diff --git a/tests/st/ops/gpu/test_resize_bilinear_grad_op.py b/tests/st/ops/gpu/test_resize_bilinear_grad_op.py index 39c9965275e..6319e246372 100644 --- a/tests/st/ops/gpu/test_resize_bilinear_grad_op.py +++ b/tests/st/ops/gpu/test_resize_bilinear_grad_op.py @@ -36,7 +36,7 @@ class ResizeBilinearGradNet(nn.Cell): @pytest.mark.env_onecard def test_resize_bilinear_grad_align_corners(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32) + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float16) x = np.array([[[[1.1, 2.2, 3.2, 2.5], [3.3, 4.4, 5.7, 8.1], @@ -49,6 +49,7 @@ def test_resize_bilinear_grad_align_corners(): net = ResizeBilinearGradNet(align_corners=True) output = net(Tensor(dy), Tensor(x)) assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32) x = np.array([[[[1.1, 2.2, 3.2, 2.5], [3.3, 4.4, 5.7, 8.1], @@ -71,7 +72,7 @@ def test_resize_bilinear_grad(): dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], - [0, 0, 1, 1]]]]).astype(np.float32) + [0, 0, 1, 1]]]]).astype(np.float16) x = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float16) expect = np.array([[[[2.25, 0.75], @@ -80,6 +81,10 @@ def test_resize_bilinear_grad(): output = net(Tensor(dy), Tensor(x)) assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 1, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 1], + [0, 0, 1, 1]]]]).astype(np.float32) x = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float32) expect = np.array([[[[2.25, 0.75], [0.75, 4.25]]]]).astype(np.float32)