From f4ae9cdbc48b7c736c880b472eceb4ea84c8f4e9 Mon Sep 17 00:00:00 2001
From: hebotao
Date: Fri, 11 Dec 2020 17:28:16 +0800
Subject: [PATCH] Add SquaredDifference kernel for GPU
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add test cases

Clear warnings
---
 .../kernel_compiler/cpu/random_cpu_kernel.cc   |   1 -
 .../kernel_compiler/cpu/select_cpu_kernel.cc   |   1 -
 .../gpu/cuda_impl/broadcast_impl.cu            |  15 +
 .../gpu/cuda_impl/broadcast_impl.cuh           |   1 +
 .../gpu/math/squared_difference_kernel.cc      |  40 +++
 .../gpu/math/squared_difference_kernel.h       | 144 ++++++++
 .../st/ops/gpu/test_squared_difference_op.py   | 308 ++++++++++++++++++
 7 files changed, 508 insertions(+), 2 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/squared_difference_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/math/squared_difference_kernel.h
 create mode 100644 tests/st/ops/gpu/test_squared_difference_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc
index 461833e2a48..542049126b0 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/random_cpu_kernel.cc
@@ -105,6 +105,5 @@ bool RandomCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
   }
   return true;
 }
-
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/select_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/select_cpu_kernel.cc
index c61cbf2d09c..470fb429556 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/select_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/select_cpu_kernel.cc
@@ -48,6 +48,5 @@ bool SelectCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std
   }
   return true;
 }
-
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
index 5277bb9f7c2..f8080eccf19 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
@@ -192,6 +192,14 @@ struct AbsGradFunc {
   }
 };
 
+template <typename T>
+struct SquaredDifferenceFunc {
+  __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
+    T diff = lhs - rhs;
+    return diff * diff;
+  }
+};
+
 // Element-wise Comparation
 template <typename T>
 __global__ void ElewiseCmpKernel(const int nums, const T *x0, const T *x1, bool *y) {
@@ -260,6 +268,8 @@ void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, c
       return ElewiseArithKernel<T, DivFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
     case BROADCAST_TYPE_DIVNONAN:
       return ElewiseArithKernel<T, DivNoNanFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_SQUARED_DIFFERENCE:
+      return ElewiseArithKernel<T, SquaredDifferenceFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
     default:
       break;
   }
@@ -481,6 +491,11 @@ void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_
       x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
       x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
       y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_SQUARED_DIFFERENCE:
+      return BroadcastArithKernel<T, SquaredDifferenceFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     default:
       break;
   }
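For reference, SquaredDifferenceFunc computes (lhs - rhs)^2 elementwise, and the two new switch cases route it through the same element-wise and broadcast launch paths as the other binary ops. A minimal NumPy sketch of the intended semantics (illustrative only, not part of the patch):

    import numpy as np

    def squared_difference_reference(x, y):
        # Mirrors SquaredDifferenceFunc: subtract (with NumPy broadcasting
        # standing in for the BroadcastArithKernel path), then square.
        diff = x - y
        return diff * diff

    # A (1, 4, 1, 2) operand against a (3, 1, 5, 1) operand broadcasts to (3, 4, 5, 2).
    out = squared_difference_reference(np.random.rand(1, 4, 1, 2), np.random.rand(3, 1, 5, 1))
    assert out.shape == (3, 4, 5, 2)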
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
index c72c144ef58..f3d0fa51055 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
@@ -37,6 +37,7 @@ enum BroadcastOpType {
   BROADCAST_TYPE_DIV = 11,
   BROADCAST_TYPE_DIVNONAN = 12,
   BROADCAST_TYPE_EQUAL = 13,
+  BROADCAST_TYPE_SQUARED_DIFFERENCE = 14,
   BROADCAST_TYPE_INVALID = 0xffffffff,
 };
 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/squared_difference_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/squared_difference_kernel.cc
new file mode 100644
index 00000000000..2395d2e1c78
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/squared_difference_kernel.cc
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/math/squared_difference_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+// fp32
+MS_REG_GPU_KERNEL_ONE(
+  SquaredDifference,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  SquaredDifferenceOpGpuKernel, float)
+
+// fp16
+MS_REG_GPU_KERNEL_ONE(
+  SquaredDifference,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  SquaredDifferenceOpGpuKernel, half)
+
+// int32
+MS_REG_GPU_KERNEL_ONE(
+  SquaredDifference,
+  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  SquaredDifferenceOpGpuKernel, int)
+
+}  // namespace kernel
+}  // namespace mindspore
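The registrations above expose the kernel for matching float32, float16, and int32 input pairs. A hedged usage sketch from the Python side (it mirrors the test net later in this patch; P.SquaredDifference is the primitive the tests use):

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    squared_diff = P.SquaredDifference()
    # Both inputs must share one of the registered dtypes (float32 here).
    out = squared_diff(Tensor(np.array([1.0, 2.0], np.float32)),
                       Tensor(np.array([3.0, 5.0], np.float32)))
    print(out)  # [4. 9.]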
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/squared_difference_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/squared_difference_kernel.h
new file mode 100644
index 00000000000..1eb201a797d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/squared_difference_kernel.h
@@ -0,0 +1,144 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SQUARED_DIFFERENCE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SQUARED_DIFFERENCE_GPU_KERNEL_H_
+
+#include <cuda_runtime_api.h>
+#include <vector>
+#include <string>
+#include <map>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+namespace mindspore {
+namespace kernel {
+constexpr int MAX_DIMS = 7;
+template <typename T>
+class SquaredDifferenceOpGpuKernel : public GpuKernel {
+ public:
+  SquaredDifferenceOpGpuKernel() { ResetResource(); }
+  ~SquaredDifferenceOpGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *lhs = GetDeviceAddress<T>(inputs, 0);
+    T *rhs = GetDeviceAddress<T>(inputs, 1);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    if (need_broadcast_) {
+      BroadcastArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else {
+      ElewiseArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    auto input_shape1 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
+    auto input_shape2 = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
+    auto output_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
+    need_broadcast_ = IsBroadcast(input_shape1, input_shape2);
+    if (need_broadcast_ && output_shape.size() > MAX_DIMS) {
+      MS_LOG(EXCEPTION) << "Broadcast operation does not support dims greater than 7";
+    }
+
+    lhs_shape_.resize(MAX_DIMS, 1);
+    rhs_shape_.resize(MAX_DIMS, 1);
+    output_shape_.resize(MAX_DIMS, 1);
+    for (size_t i = 0; i < output_shape.size(); i++) {
+      if (need_broadcast_) {
+        output_shape_[i] = output_shape[i];
+      }
+      output_num_ *= output_shape[i];
+    }
+    int lhs_offset = output_shape.size() - input_shape1.size();
+    for (size_t j = 0; j < input_shape1.size(); j++) {
+      if (need_broadcast_) {
+        lhs_shape_[j + lhs_offset] = input_shape1[j];
+      }
+      input1_num_ *= input_shape1[j];
+    }
+    int rhs_offset = output_shape.size() - input_shape2.size();
+    for (size_t k = 0; k < input_shape2.size(); k++) {
+      if (need_broadcast_) {
+        rhs_shape_[k + rhs_offset] = input_shape2[k];
+      }
+      input2_num_ *= input_shape2[k];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+  void ResetResource() noexcept override {
+    op_type_ = BROADCAST_TYPE_SQUARED_DIFFERENCE;
+    need_broadcast_ = false;
+    input1_num_ = 1;
+    input2_num_ = 1;
+    output_num_ = 1;
+    lhs_shape_.clear();
+    rhs_shape_.clear();
+    output_shape_.clear();
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
+ protected:
+  void InitResource() override { return; }
+  void InitSizeLists() override {
+    input_size_list_.push_back(input1_num_ * sizeof(T));
+    input_size_list_.push_back(input2_num_ * sizeof(T));
+    output_size_list_.push_back(output_num_ * sizeof(T));
+  }
+
+ private:
+  bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
+    if (lhs.size() != rhs.size()) {
+      return true;
+    }
+    for (size_t i = 0; i < lhs.size(); i++) {
+      if (lhs[i] != rhs[i]) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  BroadcastOpType op_type_;
+  bool need_broadcast_;
+  bool is_comp_op_;
+  size_t input1_num_;
+  size_t input2_num_;
+  size_t output_num_;
+  std::vector<size_t> lhs_shape_;
+  std::vector<size_t> rhs_shape_;
+  std::vector<size_t> output_shape_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SQUARED_DIFFERENCE_GPU_KERNEL_H_
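Init() right-aligns each input shape against the output shape and pads the shape buffers out to MAX_DIMS = 7 with ones, which is what lets a rank-2 tensor broadcast against a rank-4 one in the tests below. A small Python sketch of that offset logic (illustrative; align_shapes is not a function in the patch):

    MAX_DIMS = 7

    def align_shapes(out_shape, in_shape):
        # Pad to MAX_DIMS with 1s, then copy the input dims right-aligned
        # against the output rank (the lhs_offset / rhs_offset computation).
        padded = [1] * MAX_DIMS
        offset = len(out_shape) - len(in_shape)
        for j, dim in enumerate(in_shape):
            padded[j + offset] = dim
        return padded

    # A (1, 2) input under a (3, 4, 5, 2) output lines up its trailing axes:
    assert align_shapes([3, 4, 5, 2], [1, 2]) == [1, 1, 1, 2, 1, 1, 1]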
diff --git a/tests/st/ops/gpu/test_squared_difference_op.py b/tests/st/ops/gpu/test_squared_difference_op.py
new file mode 100644
index 00000000000..d36826e1b3c
--- /dev/null
+++ b/tests/st/ops/gpu/test_squared_difference_op.py
@@ -0,0 +1,308 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class SquaredDifference(nn.Cell):
+    def __init__(self):
+        super(SquaredDifference, self).__init__()
+        self.squaredDiff = P.SquaredDifference()
+
+    def construct(self, x, y):
+        return self.squaredDiff(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_nobroadcast_f16():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.uniform(10, 20, (3, 4, 5, 2)).astype(np.float16)
+    input_y = np.random.uniform(40, 50, (3, 4, 5, 2)).astype(np.float16)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    assert np.all(output == expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_nobroadcast_f32():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(3, 4, 5, 2).astype(np.float32)
+    input_y = np.random.rand(3, 4, 5, 2).astype(np.float32)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    assert np.all(output == expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_nobroadcast_int32():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.randint(0, 100, (3, 4, 5, 2)).astype(np.int32)
+    input_y = np.random.randint(0, 100, (3, 4, 5, 2)).astype(np.int32)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    assert np.all(output == expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_int32():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.randint(0, 100, (1, 4, 1, 2)).astype(np.int32)
+    input_y = np.random.randint(0, 100, (3, 1, 5, 1)).astype(np.int32)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    assert np.all(output == expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_f32():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(1, 4, 1, 2).astype(np.float32)
+    input_y = np.random.rand(3, 1, 5, 1).astype(np.float32)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    assert np.all(output == expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_f16():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(1, 4, 1, 2).astype(np.float16)
+    input_y = np.random.rand(3, 1, 5, 1).astype(np.float16)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    assert np.all(output == expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_bool():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(1, 4, 1, 2).astype(np.bool)
+    input_y = np.random.uniform(10, 20, (3, 1, 5, 1)).astype(np.float32)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    error = np.ones(shape=np.array(output.shape, dtype=int))*1.0e-6
+    double_check = np.abs(output-expect)/expect
+    assert np.all(double_check < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_nobroadcast_bool():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(3, 4, 5, 2).astype(np.bool)
+    input_y = np.random.rand(3, 4, 5, 2).astype(np.float32)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    error = np.ones(shape=np.array(output.shape, dtype=int))*1.0e-6
+    double_check = np.abs(output-expect)/expect
+    assert np.all(double_check < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_int32_f16():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(1, 4, 1, 2).astype(np.int32)
+    input_y = np.random.uniform(10, 20, (3, 1, 5, 1)).astype(np.float16)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    error = np.ones(shape=np.array(output.shape, dtype=int))*1.0e-3
+    double_check = np.abs(output-expect)/expect
+    assert np.all(double_check < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_int32_f32():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(1, 4, 1, 2).astype(np.int32)
+    input_y = np.random.uniform(10, 20, (3, 1, 5, 1)).astype(np.float32)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    error = np.ones(shape=np.array(output.shape, dtype=int))*1.0e-6
+    double_check = np.abs(output-expect)/expect
+    assert np.all(double_check < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_nobroadcast_int32_f16():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(2, 4, 3, 2).astype(np.int32)
+    input_y = np.random.uniform(10, 20, (2, 4, 3, 2)).astype(np.float16)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    error = np.ones(shape=np.array(output.shape, dtype=int))*1.0e-3
+    double_check = np.abs(output-expect)/expect
+    assert np.all(double_check < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_nobroadcast_int32_f32():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(2, 4, 3, 2).astype(np.int32)
+    input_y = np.random.uniform(10, 20, (2, 4, 3, 2)).astype(np.float32)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    error = np.ones(shape=np.array(output.shape, dtype=int))*1.0e-6
+    double_check = np.abs(output-expect)/expect
+    assert np.all(double_check < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_f32_scalar_tensor():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(2).astype(np.float32)
+    input_y = np.random.rand(3, 1, 5, 1).astype(np.float32)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    assert np.all(output == expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_f32_tensor_tensor():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(1, 2).astype(np.float32)
+    input_y = np.random.rand(3, 1, 5, 1).astype(np.float32)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    assert np.all(output == expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_f32_tensor_tensor_dim_over_7():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(1, 2).astype(np.float32)
+    input_y = np.random.rand(3, 1, 5, 1, 3, 4, 2, 1).astype(np.float32)
+    with pytest.raises(RuntimeError):
+        net(Tensor(input_x), Tensor(input_y))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_f32_tensor_tensor_cannot_broadcast():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.rand(5, 3).astype(np.float32)
+    input_y = np.random.rand(3, 1, 5, 1, 3, 4, 2).astype(np.float32)
+    with pytest.raises(ValueError):
+        net(Tensor(input_x), Tensor(input_y))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_int_f32_precision():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.randint(20, 30, (1, 2)).astype(np.int32)
+    input_y = np.random.rand(3, 1, 5, 1).astype(np.float32)
+    output = net(Tensor(input_x), Tensor(input_y)).asnumpy()
+    diff = input_x-input_y
+    expect = diff*diff
+    error = np.ones(shape=np.array(output.shape, dtype=int))*1.0e-3
+    double_check = np.abs(output-expect)/expect
+    assert np.all(double_check < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_type_error():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    np.random.seed(42)
+    net = SquaredDifference()
+    input_x = np.random.randint(20, 30, (1, 2)).astype(np.bool)
+    input_y = np.random.rand(3, 1, 5, 1).astype(np.bool)
+    with pytest.raises(TypeError):
+        net(Tensor(input_x), Tensor(input_y))
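The mixed-dtype tests above accept a result when its relative error stays under 1.0e-3 for float16 outputs and 1.0e-6 for float32 outputs, rather than demanding exact equality. A small helper showing the check each of them inlines (illustrative only; the inputs are chosen so expect is never zero):

    import numpy as np

    def within_relative_error(output, expect, tol):
        # |output - expect| / expect < tol, elementwise; assumes expect > 0.
        return np.all(np.abs(output - expect) / expect < tol)

    assert within_relative_error(np.array([4.000001]), np.array([4.0]), 1.0e-6)

On a GPU machine the suite runs under the standard MindSpore ST harness, e.g. pytest -sv tests/st/ops/gpu/test_squared_difference_op.py.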