modify resizebilineargrad input type

This commit is contained in:
simson 2021-09-02 16:11:31 +08:00
parent 50847c9659
commit f00e22342b
6 changed files with 101 additions and 120 deletions

View File

@ -49,7 +49,7 @@ __global__ void ResizeBilinear(const T *input, const int n, const int c, const i
}
// fp16 path
__global__ void ResizeBilinearGrad(const float *input, const int n, const int c, const int input_h, const int input_w,
__global__ void ResizeBilinearGrad(const half *input, const int n, const int c, const int input_h, const int input_w,
const int output_h, const int output_w, const int nchw, const int chw, const int hw, const float h_scale,
const float w_scale, half *output, float *interim) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) {
@ -67,11 +67,11 @@ __global__ void ResizeBilinearGrad(const float *input, const int n, const int c,
const float w_beta = 1.0f - w_alpha;
const float h_alpha = posh_scaled - h_low;
const float h_beta = 1.0f - h_alpha;
const float grad = input[pos];
const float dp1 = h_beta * w_beta * grad;
const float dp2 = h_beta * w_alpha * grad;
const float dp3 = h_alpha * w_beta * grad;
const float dp4 = h_alpha * w_alpha * grad;
const half grad = input[pos];
const half dp1 = static_cast<half>(h_beta * w_beta) * grad;
const half dp2 = static_cast<half>(h_beta * w_alpha) * grad;
const half dp3 = static_cast<half>(h_alpha * w_beta) * grad;
const half dp4 = static_cast<half>(h_alpha * w_alpha) * grad;
const int output_start = output_h * output_w * (posn * c + posc);
atomicAdd(&interim[output_start + (h_low * output_w) + w_low], dp1);
atomicAdd(&interim[output_start + (h_low * output_w) + w_high], dp2);
@ -133,7 +133,7 @@ void CalResizeBilinear(const T *input, const int n, const int c, const int input
return;
}
void CalResizeBilinearGrad(const float *input, const int n, const int c, const int input_h, const int input_w,
void CalResizeBilinearGrad(const half *input, const int n, const int c, const int input_h, const int input_w,
const int output_h, const int output_w, const float h_scale, const float w_scale, half *output, float *interim,
cudaStream_t cuda_stream) {
const int hw = input_h * input_w;

View File

@ -21,7 +21,7 @@ template <typename T>
void CalResizeBilinear(const T *input, const int n_, const int c_, const int input_h_, const int input_w_,
const int output_h_, const int output_w_, const float h_scale, const float w_scale, T *output,
cudaStream_t cuda_stream);
void CalResizeBilinearGrad(const float *input, const int n_, const int c_, const int input_h_, const int input_w_,
void CalResizeBilinearGrad(const half *input, const int n_, const int c_, const int input_h_, const int input_w_,
const int output_h_, const int output_w_, const float h_scale, const float w_scale, half *output, float *interim,
cudaStream_t cuda_stream);
void CalResizeBilinearGrad(const float *input, const int n_, const int c_, const int input_h_, const int input_w_,

View File

@ -25,7 +25,7 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
ResizeBilinearGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ResizeBilinearGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@ -36,7 +36,7 @@ class ResizeBilinearGradGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
float *dy = GetDeviceAddress<float>(inputs, 0);
T *dy = GetDeviceAddress<T>(inputs, 0);
float *interim = GetDeviceAddress<float>(workspace, 0);
T *dx = GetDeviceAddress<T>(outputs, 0);
float h_scale = Scaling(dx_h_, dy_h_, align_corners_);
@ -81,7 +81,7 @@ class ResizeBilinearGradGpuKernel : public GpuKernel {
dy_w_ = SizeToInt(dy_shape[3]);
dx_h_ = SizeToInt(dx_shape[2]);
dx_w_ = SizeToInt(dx_shape[3]);
dy_size_ = sizeof(float);
dy_size_ = sizeof(T);
for (auto x : dy_shape) {
dy_size_ *= x;
}
@ -89,7 +89,7 @@ class ResizeBilinearGradGpuKernel : public GpuKernel {
for (auto x : dx_shape) {
dx_size_ *= x;
}
workspace_size_ = (dx_size_ / sizeof(T)) * sizeof(float);
workspace_size_ = (dx_size_ / sizeof(T)) * sizeof(T);
align_corners_ = GetAttr<bool>(kernel_node, "align_corners");
InitSizeLists();
return true;

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -37,24 +37,16 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16):
# larger h and w
resize_nn = NetResizeBilinear((9, 9))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.13330078, 0.16662598, 0.19995117, 0.23331706,
0.26668295, 0.30004883, 0.30004883, 0.30004883],
[0.19995117, 0.23328993, 0.26662868, 0.29996747, 0.33333334,
0.36669925, 0.40006512, 0.40006512, 0.40006512],
[0.29992676, 0.33327907, 0.36663142, 0.39998373, 0.4333496,
0.4667155, 0.5000814, 0.5000814, 0.5000814],
[0.39990234, 0.43326822, 0.46663412, 0.5, 0.5333659,
0.5667318, 0.60009766, 0.60009766, 0.60009766],
[0.5, 0.5333116, 0.5666233, 0.59993494, 0.6333008,
0.66666675, 0.7000326, 0.7000326, 0.7000326],
[0.60009766, 0.633355, 0.66661245, 0.6998698, 0.7332357,
0.7666016, 0.79996747, 0.79996747, 0.79996747],
[0.7001953, 0.73339844, 0.76660156, 0.7998047, 0.8331706,
0.8665365, 0.89990234, 0.89990234, 0.89990234],
[0.7001953, 0.73339844, 0.76660156, 0.7998047, 0.8331706,
0.8665365, 0.89990234, 0.89990234, 0.89990234],
[0.7001953, 0.73339844, 0.76660156, 0.7998047, 0.8331706,
0.8665365, 0.89990234, 0.89990234, 0.89990234]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.1333, 0.1666, 0.2, 0.2333, 0.2666, 0.3, 0.3, 0.3],
[0.2, 0.2333, 0.2666, 0.2998, 0.3333, 0.3667, 0.4, 0.4, 0.4],
[0.2998, 0.3333, 0.3665, 0.4, 0.433, 0.4668, 0.5, 0.5, 0.5],
[0.4, 0.4333, 0.4666, 0.5, 0.533, 0.567, 0.6, 0.6, 0.6],
[0.5, 0.533, 0.5664, 0.6, 0.6333, 0.667, 0.7, 0.7, 0.7],
[0.6, 0.6333, 0.6665, 0.6997, 0.733, 0.7666, 0.8, 0.8, 0.8],
[0.7, 0.7334, 0.7666, 0.8, 0.833, 0.8667, 0.9, 0.9, 0.9],
[0.7, 0.7334, 0.7666, 0.8, 0.833, 0.8667, 0.9, 0.9, 0.9],
[0.7, 0.7334, 0.7666, 0.8, 0.833, 0.8667, 0.9, 0.9, 0.9]]]]
).astype(np.float16))
error = np.ones(shape=[9, 9]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -62,7 +54,7 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16):
# smaller h and w
resize_nn = NetResizeBilinear((1, 1))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1]]]]).astype(np.float16))
error = np.ones(shape=[1, 1]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -71,7 +63,7 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16):
resize_nn = NetResizeBilinear((1, 6))
output = resize_nn(input_tensor)
expected_output = Tensor(
np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883, 0.30004883]]]]).astype(np.float32))
np.array([[[[0.1, 0.1499, 0.2, 0.25, 0.3, 0.3]]]]).astype(np.float16))
error = np.ones(shape=[1, 6]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -80,8 +72,12 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16):
resize_nn = NetResizeBilinear((6, 1))
output = resize_nn(input_tensor)
expected_output = Tensor(
np.array([[[[0.09997559], [0.24993896], [0.39990234], [0.5500488], [0.7001953], [0.7001953]]]]).astype(
np.float32))
np.array([[[[0.1],
[0.2499],
[0.4],
[0.55],
[0.7],
[0.7]]]]).astype(np.float16))
error = np.ones(shape=[6, 1]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -90,7 +86,7 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16):
resize_nn = NetResizeBilinear((1, 3))
output = resize_nn(input_tensor)
expected_output = Tensor(
np.array([[[[0.09997559, 0.19995117, 0.30004883]]]]).astype(np.float32))
np.array([[[[0.1, 0.2, 0.3]]]]).astype(np.float16))
error = np.ones(shape=[1, 3]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -98,12 +94,12 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16):
# larger h, same w
resize_nn = NetResizeBilinear((6, 3))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883],
[0.24993896, 0.3499756, 0.45007324],
[0.39990234, 0.5, 0.60009766],
[0.5500488, 0.64990234, 0.75],
[0.7001953, 0.7998047, 0.89990234],
[0.7001953, 0.7998047, 0.89990234]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3],
[0.2499, 0.35, 0.4502],
[0.4, 0.5, 0.6],
[0.55, 0.65, 0.75],
[0.7, 0.8, 0.9],
[0.7, 0.8, 0.9]]]]).astype(np.float16))
error = np.ones(shape=[6, 3]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -112,7 +108,9 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16):
resize_nn = NetResizeBilinear((3, 1))
output = resize_nn(input_tensor)
expected_output = Tensor(
np.array([[[[0.09997559], [0.39990234], [0.7001953]]]]).astype(np.float32))
np.array([[[[0.1],
[0.4],
[0.7]]]]).astype(np.float16))
error = np.ones(shape=[3, 1]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -120,12 +118,9 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16):
# same h, larger w
resize_nn = NetResizeBilinear((3, 6))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883,
0.30004883],
[0.39990234, 0.44995117, 0.5, 0.5500488, 0.60009766,
0.60009766],
[0.7001953, 0.75, 0.7998047, 0.8498535, 0.89990234,
0.89990234]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.1499, 0.2, 0.25, 0.3, 0.3],
[0.4, 0.45, 0.5, 0.55, 0.6, 0.6],
[0.7, 0.75, 0.8, 0.8496, 0.9, 0.9]]]]).astype(np.float16))
error = np.ones(shape=[3, 6]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -134,9 +129,9 @@ def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16):
resize_nn = NetResizeBilinear((3, 3))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array(
[[[[0.09997559, 0.19995117, 0.30004883],
[0.39990234, 0.5, 0.60009766],
[0.7001953, 0.7998047, 0.89990234]]]]).astype(np.float32))
[[[[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9]]]]).astype(np.float16))
error = np.ones(shape=[3, 3]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -251,20 +246,13 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16):
# larger h and w
resize_nn = NetResizeBilinear((7, 7))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.15710449, 0.21425085, 0.2714495, 0.3285784,
0.38563755, 0.39990234],
[0.27141464, 0.3285734, 0.3857422, 0.44294086, 0.5000399,
0.55703926, 0.57128906],
[0.44285366, 0.5000423, 0.5572336, 0.6144322, 0.67150134,
0.7284409, 0.7426758],
[0.6142578, 0.50819117, 0.44293588, 0.5001146, 0.5571937,
0.6141731, 0.62841797],
[0.78564453, 0.4346799, 0.18574369, 0.2428925, 0.3000015,
0.3570706, 0.3713379],
[0.89990234, 0.3856724, 0.01428223, 0.07141115, 0.12854005,
0.18566895, 0.19995117],
[0.89990234, 0.3856724, 0.01428223, 0.07141115, 0.12854005,
0.18566895, 0.19995117]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.1571, 0.2142, 0.2715, 0.3286, 0.3857, 0.4],
[0.2715, 0.3286, 0.3857, 0.4429, 0.5, 0.557, 0.5713],
[0.4429, 0.5, 0.557, 0.6143, 0.6714, 0.7285, 0.7427],
[0.6143, 0.5083, 0.4429, 0.5005, 0.557, 0.6143, 0.6284],
[0.7856, 0.4346, 0.1855, 0.2429, 0.2998, 0.357, 0.3716],
[0.9, 0.3857, 0.01428, 0.0714, 0.1285, 0.1857, 0.2],
[0.9, 0.3857, 0.01428, 0.0714, 0.1285, 0.1857, 0.2]]]]).astype(np.float16))
error = np.ones(shape=[7, 7]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -273,8 +261,8 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16):
resize_nn = NetResizeBilinear((2, 3))
output = resize_nn(input_tensor)
expected_output = Tensor(
np.array([[[[0.09997559, 0.23331706, 0.36661786],
[0.6999512, 0.33339438, 0.46661377]]]]).astype(np.float32))
np.array([[[[0.1, 0.2333, 0.3667],
[0.7, 0.3333, 0.4666]]]]).astype(np.float16))
error = np.ones(shape=[2, 3]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -282,10 +270,8 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16):
# smaller h, larger w
resize_nn = NetResizeBilinear((2, 7))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.15710449, 0.21425085, 0.2714495, 0.3285784,
0.38563755, 0.39990234],
[0.6999512, 0.47143552, 0.3143398, 0.37150356, 0.4285976,
0.48562187, 0.49987793]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.1571, 0.2142, 0.2715, 0.3286, 0.3857, 0.4],
[0.7, 0.4714, 0.3142, 0.3716, 0.4285, 0.4856, 0.5]]]]).astype(np.float16))
error = np.ones(shape=[2, 7]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -293,11 +279,11 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16):
# larger h, smaller w
resize_nn = NetResizeBilinear((5, 3))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.23331706, 0.36661786],
[0.33999026, 0.47340494, 0.6066081],
[0.5799805, 0.51343584, 0.64660645],
[0.8199219, 0.15335283, 0.28662106],
[0.89990234, 0.0333252, 0.16662598]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.2333, 0.3667],
[0.3398, 0.4731, 0.6064],
[0.58, 0.513, 0.6465],
[0.82, 0.1533, 0.2866],
[0.9, 0.03333, 0.1666]]]]).astype(np.float16))
error = np.ones(shape=[5, 3]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -305,8 +291,8 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16):
# smaller h, same w
resize_nn = NetResizeBilinear((2, 4))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883, 0.39990234],
[0.6999512, 0.30004883, 0.40008545, 0.49987793]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4],
[0.7, 0.3, 0.4001, 0.5]]]]).astype(np.float16))
error = np.ones(shape=[2, 4]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -314,14 +300,14 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16):
# larger h, same w
resize_nn = NetResizeBilinear((8, 4))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883, 0.39990234],
[0.24998474, 0.3500061, 0.45010376, 0.5498657],
[0.3999939, 0.50006104, 0.6001587, 0.6998291],
[0.5499878, 0.52508545, 0.62516785, 0.724823],
[0.6999512, 0.30004883, 0.40008545, 0.49987793],
[0.84991455, 0.07501221, 0.17500305, 0.27493286],
[0.89990234, 0., 0.09997559, 0.19995117],
[0.89990234, 0., 0.09997559, 0.19995117]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4],
[0.2499, 0.35, 0.4502, 0.55],
[0.4, 0.5, 0.6, 0.6997],
[0.55, 0.525, 0.625, 0.7246],
[0.7, 0.3, 0.4001, 0.5],
[0.8496, 0.0752, 0.1753, 0.2754],
[0.9, 0., 0.1, 0.2],
[0.9, 0., 0.1, 0.2]]]]).astype(np.float16))
error = np.ones(shape=[8, 4]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -329,9 +315,9 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16):
# same h, smaller w
resize_nn = NetResizeBilinear((3, 2))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.30004883],
[0.5, 0.7001953],
[0.89990234, 0.09997559]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.3],
[0.5, 0.7],
[0.9, 0.1]]]]).astype(np.float16))
error = np.ones(shape=[3, 2]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -339,12 +325,9 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16):
# same h, larger w
resize_nn = NetResizeBilinear((3, 6))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.16662598, 0.23331706, 0.30004883, 0.36661786,
0.39990234],
[0.5, 0.56673175, 0.63346356, 0.7001953, 0.76660156,
0.7998047],
[0.89990234, 0.2999674, 0.0333252, 0.09997559, 0.16662598,
0.19995117]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.1666, 0.2333, 0.3, 0.3667, 0.4],
[0.5, 0.567, 0.6333, 0.7, 0.7666, 0.8],
[0.9, 0.3003, 0.03333, 0.1, 0.1666, 0.2]]]]).astype(np.float16))
error = np.ones(shape=[3, 6]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -352,9 +335,9 @@ def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16):
# same w, same h (identity)
resize_nn = NetResizeBilinear((3, 4))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883, 0.39990234],
[0.5, 0.60009766, 0.7001953, 0.7998047],
[0.89990234, 0., 0.09997559, 0.19995117]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4],
[0.5, 0.6, 0.7, 0.8],
[0.9, 0., 0.1, 0.2]]]]).astype(np.float16))
error = np.ones(shape=[3, 4]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
@ -476,13 +459,12 @@ def test_resize_nn_grayscale_multiple_images_half(datatype=np.float16):
resize_nn = NetResizeBilinear((2, 6))
output = resize_nn(input_tensor)
expected_output = Tensor(np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883, 0.30004883],
[0.5500488, 0.5999756, 0.64990234, 0.6999512, 0.75, 0.75]]],
[[[0.39990234, 0.44995117, 0.5, 0.5500488, 0.60009766, 0.60009766],
[0.40008545, 0.4499817, 0.49987793, 0.54992676, 0.5999756, 0.5999756]]],
[[[0.7001953, 0.75, 0.7998047, 0.8498535, 0.89990234, 0.89990234],
[0.24993896, 0.29995728, 0.3499756, 0.4000244, 0.45007324,
0.45007324]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.1, 0.1499, 0.2, 0.25, 0.3, 0.3],
[0.55, 0.6, 0.65, 0.6997, 0.75, 0.75]]],
[[[0.4, 0.45, 0.5, 0.55, 0.6, 0.6],
[0.4001, 0.45, 0.5, 0.55, 0.6, 0.6]]],
[[[0.7, 0.75, 0.8, 0.8496, 0.9, 0.9],
[0.2499, 0.2998, 0.35, 0.4, 0.4502, 0.4502]]]]).astype(np.float16))
error = np.ones(shape=[3, 3, 2, 6]) * 1.0e-6
diff = output.asnumpy() - expected_output.asnumpy()
@ -520,18 +502,12 @@ def test_resize_nn_grayscale_align_corners_half(datatype=np.float16):
resize_nn = NetResizeBilinear((3, 7))
output = resize_nn(input_tensor)
expected_output_align = Tensor(np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883,
0.3499756, 0.39990234],
[0.2999878, 0.3500061, 0.4000244, 0.45007324, 0.5001221,
0.5499878, 0.5998535],
[0.5, 0.5500488, 0.60009766, 0.6501465, 0.7001953,
0.75, 0.7998047]]]]).astype(np.float32))
expected_output = Tensor(np.array([[[[0.09997559, 0.15710449, 0.21425085, 0.2714495, 0.3285784,
0.38563755, 0.39990234],
[0.36665854, 0.42383394, 0.4810152, 0.53821385, 0.59529626,
0.6522624, 0.6665039],
[0.5, 0.55719864, 0.61439735, 0.671596, 0.72865516,
0.7855748, 0.7998047]]]]).astype(np.float32))
expected_output_align = Tensor(np.array([[[[0.1, 0.1499, 0.2, 0.25, 0.3, 0.35, 0.4],
[0.2998, 0.3499, 0.4, 0.4502, 0.5, 0.55, 0.5996],
[0.5, 0.55, 0.6, 0.6504, 0.7, 0.75, 0.8]]]]).astype(np.float16))
expected_output = Tensor(np.array([[[[0.1, 0.1571, 0.2142, 0.2715, 0.3286, 0.3857, 0.4],
[0.3667, 0.4238, 0.481, 0.538, 0.595, 0.6523, 0.6665],
[0.5, 0.557, 0.6143, 0.672, 0.7285, 0.7856, 0.8]]]]).astype(np.float16))
error = np.ones(shape=[3, 7]) * 1.0e-6
diff_align = output_corners_aligned.asnumpy() - expected_output_align.asnumpy()

View File

@ -36,7 +36,7 @@ class ResizeBilinearGradNet(nn.Cell):
@pytest.mark.env_onecard
def test_resize_bilinear_grad_align_corners():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32)
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float16)
x = np.array([[[[1.1, 2.2, 3.2, 2.5],
[3.3, 4.4, 5.7, 8.1],
@ -49,6 +49,7 @@ def test_resize_bilinear_grad_align_corners():
net = ResizeBilinearGradNet(align_corners=True)
output = net(Tensor(dy), Tensor(x))
assert np.all(output.asnumpy() == expect)
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32)
x = np.array([[[[1.1, 2.2, 3.2, 2.5],
[3.3, 4.4, 5.7, 8.1],
@ -71,7 +72,7 @@ def test_resize_bilinear_grad():
dy = np.array([[[[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 1, 1],
[0, 0, 1, 1]]]]).astype(np.float32)
[0, 0, 1, 1]]]]).astype(np.float16)
x = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float16)
expect = np.array([[[[2.25, 0.75],
@ -80,6 +81,10 @@ def test_resize_bilinear_grad():
output = net(Tensor(dy), Tensor(x))
assert np.all(output.asnumpy() == expect)
dy = np.array([[[[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 1, 1],
[0, 0, 1, 1]]]]).astype(np.float32)
x = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float32)
expect = np.array([[[[2.25, 0.75],
[0.75, 4.25]]]]).astype(np.float32)