From 31a22a633d37fb0c010c8917eecc04824e2582bf Mon Sep 17 00:00:00 2001 From: li-qiyao <1040772198@qq.com> Date: Sun, 9 Oct 2022 21:49:18 +0800 Subject: [PATCH] [feat][assistant][I5EWL4] add new gpu operator Hypot --- .../cuda_impl/cuda_class/hypot_helper.h | 155 ++++++++++++++++++ .../kernel/cuda_impl/cuda_ops/hypot_impl.cu | 144 ++++++++++++++++ .../kernel/cuda_impl/cuda_ops/hypot_impl.cuh | 31 ++++ .../gpu/kernel/math/hypot_gpu_kernel.cc | 100 +++++++++++ .../device/gpu/kernel/math/hypot_gpu_kernel.h | 61 +++++++ .../mindspore/ops/operations/math_ops.py | 2 +- tests/st/ops/gpu/test_hypot_op.py | 70 ++++++++ 7 files changed, 562 insertions(+), 1 deletion(-) create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cu create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_hypot_op.py diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h new file mode 100644 index 00000000000..1156ee788e7 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h @@ -0,0 +1,155 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_
+#include <memory>
+#include <string>
+#include <vector>
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh"
+
+namespace mindspore {
+namespace cukernel {
+constexpr int MAX_DIMS = 7;
+template <typename T>
+class HypotHelperGpuKernel : public GpuKernelHelperBase {
+ public:
+  explicit HypotHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
+      : GpuKernelHelperBase(kernel_name, device_id) {
+    is_null_input_ = false;
+    need_broadcast_ = false;
+  }
+
+  virtual ~HypotHelperGpuKernel() = default;
+  int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
+                 const std::vector<std::vector<int64_t>> &output_shapes) override {
+    constexpr size_t INPUT_NUM = 2;
+    constexpr size_t OUTPUT_NUM = 1;
+    ResetResource();
+    int inp_flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
+    if (inp_flag == -1) {
+      return inp_flag;
+    }
+    int out_flag =
+      CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
+    if (out_flag == -1) {
+      return out_flag;
+    }
+    is_null_input_ = (inp_flag == 1 || out_flag == 1);
+
+    auto inputx_shape = input_shapes[0];
+    auto inputy_shape = input_shapes[1];
+    auto output_shape = output_shapes[0];
+
+    ProcessScalar(&inputx_shape, &inputy_shape, &output_shape);
+
+    // Broadcasting is needed when the two input shapes differ in rank or in any dimension.
+    // Guard the elementwise comparison so it never reads past the shorter shape.
+    if (inputx_shape.size() != inputy_shape.size()) {
+      need_broadcast_ = true;
+    }
+    for (size_t i = 0; i < inputx_shape.size() && i < inputy_shape.size(); i++) {
+      if (inputx_shape[i] != inputy_shape[i]) {
+        need_broadcast_ = true;
+      }
+    }
+
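+    // Shape alignment example (hypothetical shapes): with x1 (3, 1, 5) and x2 (4, 5),
+    // the output is (3, 4, 5). Each input is right-aligned against the output rank,
+    // then all three shapes are padded with trailing 1s up to MAX_DIMS = 7:
+    //   lhs_shape_    = {3, 1, 5, 1, 1, 1, 1}
+    //   rhs_shape_    = {1, 4, 5, 1, 1, 1, 1}
+    //   output_shape_ = {3, 4, 5, 1, 1, 1, 1}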
+    lhs_shape_.resize(MAX_DIMS, 1);
+    rhs_shape_.resize(MAX_DIMS, 1);
+    output_shape_.resize(MAX_DIMS, 1);
+    output_num_ = 1;
+    for (size_t i = 0; i < output_shape.size(); i++) {
+      if (need_broadcast_) {
+        output_shape_[i] = output_shape[i];
+      }
+      output_num_ *= output_shape[i];
+    }
+    int lhs_offset = output_shape.size() - inputx_shape.size();
+    for (size_t j = 0; j < inputx_shape.size(); j++) {
+      if (need_broadcast_) {
+        if ((j + lhs_offset) >= 0 && (j + lhs_offset) < MAX_DIMS) {
+          lhs_shape_[j + lhs_offset] = inputx_shape[j];
+        }
+      }
+    }
+    int rhs_offset = output_shape.size() - inputy_shape.size();
+    for (size_t k = 0; k < inputy_shape.size(); k++) {
+      if (need_broadcast_) {
+        if ((k + rhs_offset) >= 0 && (k + rhs_offset) < MAX_DIMS) {
+          rhs_shape_[k + rhs_offset] = inputy_shape[k];
+        }
+      }
+    }
+
+    return CheckKernelParam();
+  }
+
+  int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
+              const std::vector<void *> &work_ptrs, void *cuda_stream) override {
+    if (is_null_input_) {
+      return 0;
+    }
+
+    T *inputx_ptr = nullptr;
+    T *inputy_ptr = nullptr;
+    T *output_ptr = nullptr;
+    int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &inputx_ptr);
+    if (flag != 0) {
+      return flag;
+    }
+
+    flag = GetDeviceAddress<T>(input_ptrs, 1, kernel_name_, &inputy_ptr);
+    if (flag != 0) {
+      return flag;
+    }
+
+    flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &output_ptr);
+    if (flag != 0) {
+      return flag;
+    }
+
+    // call cuda kernel
+    if (need_broadcast_) {
+      BroadcastHypot(lhs_shape_, rhs_shape_, output_shape_, inputx_ptr, inputy_ptr, output_ptr, device_id_,
+                     reinterpret_cast<cudaStream_t>(cuda_stream));
+    } else {
+      CalHypot(output_num_, inputx_ptr, inputy_ptr, output_ptr, device_id_,
+               reinterpret_cast<cudaStream_t>(cuda_stream));
+    }
+
+    return 0;
+  }
+
+  void ProcessScalar(std::vector<int64_t> *x1_shape, std::vector<int64_t> *x2_shape, std::vector<int64_t> *y_shape) {
+    // If there is a scalar in the inputs, its shape will be [], so it will be treated as [1].
+    if (x1_shape->size() == 0) {
+      x1_shape->insert(x1_shape->begin(), 1);
+    }
+    if (x2_shape->size() == 0) {
+      x2_shape->insert(x2_shape->begin(), 1);
+    }
+    if (y_shape->size() == 0) {
+      y_shape->insert(y_shape->begin(), 1);
+    }
+  }
+
+ private:
+  std::vector<size_t> lhs_shape_;
+  std::vector<size_t> rhs_shape_;
+  std::vector<size_t> output_shape_;
+  bool need_broadcast_;
+  bool is_null_input_;
+  size_t output_num_;
+};
+}  // namespace cukernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cu
new file mode 100644
index 00000000000..f87fe2296b7
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cu
@@ -0,0 +1,144 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh"
+
+__constant__ size_t start_cal[5];
+__constant__ size_t end_cal[5];
+__constant__ size_t output_cal[5];
+
+template <typename T>
+struct HypotFunc {
+  __device__ __host__ __forceinline__ T operator()(const T &x1, const T &x2) { return hypotf(x1, x2); }
+};
+
+template <>
+struct HypotFunc<double> {
+  __device__ __host__ __forceinline__ double operator()(const double &x1, const double &x2) {
+    return hypot(x1, x2);
+  }
+};
+
+template <typename T, typename Func>
+__global__ void CalHypotKernel(size_t size, const T *x1, const T *x2, T *y) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    y[pos] = Func()(x1[pos], x2[pos]);
+  }
+}
+
+__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) { return dim == 1 ? 0 : index; }
+
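+// Broadcasting example: Index() clamps a coordinate to 0 along any dimension of
+// extent 1, which is what implements broadcasting. With x1 shape (1, 3) and
+// x2 shape (2, 1), output position (i, j) reads x1[0][j] and x2[i][0], e.g.
+//   y[1][2] = hypot(x1[0][2], x2[1][0])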
+template <typename T, typename Func>
+__global__ void BroadcastHypotKernel(
+    const size_t l0, const size_t l1, const size_t l2, const size_t l3, const size_t l4, const size_t l5,
+    const size_t l6, const size_t r0, const size_t r1, const size_t r2, const size_t r3, const size_t r4,
+    const size_t r5, const size_t r6, const size_t d0, const size_t d1, const size_t d2, const size_t d3,
+    const size_t d4, const size_t d5, const size_t d6, const T *x1, const T *x2, T *y) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
+       pos += blockDim.x * gridDim.x) {
+    size_t i = pos / output_cal[0] % d0;
+    size_t j = pos / output_cal[1] % d1;
+    size_t k = pos / output_cal[2] % d2;
+    size_t l = pos / output_cal[3] % d3;
+    size_t m = pos / output_cal[4] % d4;
+    size_t n = pos / d6 % d5;
+    size_t o = pos % d6;
+
+    size_t l_index = Index(i, l0) * start_cal[0];
+    l_index += Index(j, l1) * start_cal[1];
+    l_index += Index(k, l2) * start_cal[2];
+    l_index += Index(l, l3) * start_cal[3];
+    l_index += Index(m, l4) * start_cal[4];
+    l_index += Index(n, l5) * l6;
+    l_index += Index(o, l6);
+    size_t r_index = Index(i, r0) * end_cal[0];
+    r_index += Index(j, r1) * end_cal[1];
+    r_index += Index(k, r2) * end_cal[2];
+    r_index += Index(l, r3) * end_cal[3];
+    r_index += Index(m, r4) * end_cal[4];
+    r_index += Index(n, r5) * r6;
+    r_index += Index(o, r6);
+    y[pos] = Func()(x1[l_index], x2[r_index]);
+  }
+}
+
+template <typename T>
+void CalHypot(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id, cudaStream_t cuda_stream) {
+  return CalHypotKernel<T, HypotFunc<T>>
+    <<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, x1, x2, y);
+}
+
+void CalShapeData(const std::vector<size_t> &start_shape, size_t *output) {
+  output[4] = start_shape[5] * start_shape[6];
+  output[3] = output[4] * start_shape[4];
+  output[2] = output[3] * start_shape[3];
+  output[1] = output[2] * start_shape[2];
+  output[0] = output[1] * start_shape[1];
+}
+
+template <typename T>
+void BroadcastHypot(const std::vector<size_t> &x1_shape, const std::vector<size_t> &x2_shape,
+                    const std::vector<size_t> &y_shape, const T *x1, const T *x2, T *y, const uint32_t &device_id,
+                    cudaStream_t cuda_stream) {
+  size_t size = 1;
+  for (auto d : y_shape) {
+    size *= d;
+  }
+  size_t start_dim[5];
+  size_t end_dim[5];
+  size_t output_dim[5];
+  CalShapeData(x1_shape, start_dim);
+  CalShapeData(x2_shape, end_dim);
+  CalShapeData(y_shape, output_dim);
+  cudaMemcpyToSymbol(start_cal, start_dim, sizeof(size_t) * 5);
+  cudaMemcpyToSymbol(end_cal, end_dim, sizeof(size_t) * 5);
+  cudaMemcpyToSymbol(output_cal, output_dim, sizeof(size_t) * 5);
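+  // start_cal/end_cal/output_cal now hold the precomputed stride products for
+  // dims 0-4 of x1, x2, and y respectively (see CalShapeData); the kernel
+  // decodes output coordinates with output_cal and handles dims 5 and 6
+  // directly from d5/d6.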
+  return BroadcastHypotKernel<T, HypotFunc<T>>
+    <<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+      x1_shape[0], x1_shape[1], x1_shape[2], x1_shape[3], x1_shape[4], x1_shape[5], x1_shape[6], x2_shape[0],
+      x2_shape[1], x2_shape[2], x2_shape[3], x2_shape[4], x2_shape[5], x2_shape[6], y_shape[0], y_shape[1],
+      y_shape[2], y_shape[3], y_shape[4], y_shape[5], y_shape[6], x1, x2, y);
+}
+
+template CUDA_LIB_EXPORT void CalHypot<float>(size_t, const float *, const float *, float *, const uint32_t &,
+                                              cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void CalHypot<double>(size_t, const double *, const double *, double *, const uint32_t &,
+                                               cudaStream_t cuda_stream);
+
+template CUDA_LIB_EXPORT void BroadcastHypot<float>(const std::vector<size_t> &, const std::vector<size_t> &,
+                                                    const std::vector<size_t> &, const float *, const float *,
+                                                    float *, const uint32_t &, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void BroadcastHypot<double>(const std::vector<size_t> &, const std::vector<size_t> &,
+                                                     const std::vector<size_t> &, const double *, const double *,
+                                                     double *, const uint32_t &, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh
new file mode 100644
index 00000000000..ad9f6ed88c9
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_
+
+#include <vector>
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
+
+template <typename T>
+CUDA_LIB_EXPORT void CalHypot(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id,
+                              cudaStream_t cuda_stream);
+
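+// CalHypot above is the elementwise fast path used when both inputs already
+// have the output shape. BroadcastHypot below is the general path: the helper
+// in hypot_helper.h pads x1_shape, x2_shape, and y_shape to 7 dimensions
+// (MAX_DIMS) before calling it, so the shape vectors are expected in that form.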
+template <typename T>
+CUDA_LIB_EXPORT void BroadcastHypot(const std::vector<size_t> &x1_shape, const std::vector<size_t> &x2_shape,
+                                    const std::vector<size_t> &y_shape, const T *x1, const T *x2, T *y,
+                                    const uint32_t &device_id, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.cc
new file mode 100644
index 00000000000..16416abf67e
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.cc
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/gpu/kernel/math/hypot_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+template <typename T>
+std::unique_ptr<cukernel::GpuKernelHelperBase> CreateHypotKernelPtr(const std::string &kernel_name,
+                                                                    const uint32_t &device_id) {
+  return std::make_unique<cukernel::HypotHelperGpuKernel<T>>(kernel_name, device_id);
+}
+using HypotPtrCreatorFunc =
+  std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
+
+const std::vector<std::pair<KernelAttr, HypotPtrCreatorFunc>> kernel_attr = {
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+   CreateHypotKernelPtr<float>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+   CreateHypotKernelPtr<double>}};
+}  // namespace
+
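+// Dispatch note: Init() matches the runtime input/output dtypes against
+// kernel_attr via MatchKernelAttr; the matched index selects the creator, so
+// float32 inputs construct HypotHelperGpuKernel<float> and float64 inputs
+// construct HypotHelperGpuKernel<double>.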
+bool HypotGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                               const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  std::vector<void *> input_ptrs = ConvertPtrs(inputs);
+  std::vector<void *> work_ptrs = ConvertPtrs(workspace);
+  std::vector<void *> output_ptrs = ConvertPtrs(outputs);
+  if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
+    return false;
+  }
+  return true;
+}
+
+bool HypotGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                             const std::vector<KernelTensorPtr> &outputs) {
+  auto kernel_ptr = std::dynamic_pointer_cast<ops::Hypot>(base_operator);
+  kernel_name_ = kernel_ptr->name();
+  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
+  if (!is_match) {
+    return false;
+  }
+  helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
+
+  Resize(base_operator, inputs, outputs);
+  return true;
+}
+
+int HypotGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                              const std::vector<KernelTensorPtr> &outputs,
+                              const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
+  for (const auto &input : inputs) {
+    // If any input shape contains -1, the shape is dynamic and not yet known, so return early and do nothing.
+    auto input_shape = input->GetShapeVector();
+    if (!IsValidShape(input_shape)) {
+      return KRET_UNKNOWN_SHAPE;
+    }
+  }
+
+  std::vector<std::vector<int64_t>> input_shapes;
+  std::vector<std::vector<int64_t>> output_shapes;
+  std::vector<int64_t> inpx_shape = inputs.at(kIndex0)->GetShapeVector();
+  std::vector<int64_t> inpy_shape = inputs.at(kIndex1)->GetShapeVector();
+  std::vector<int64_t> out_shape = outputs.at(kIndex0)->GetShapeVector();
+  input_shapes.emplace_back(inpx_shape);
+  input_shapes.emplace_back(inpy_shape);
+  output_shapes.emplace_back(out_shape);
+  if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
+    return KRET_RESIZE_FAILED;
+  }
+  input_size_list_ = helper_ptr_->GetInputSizeList();
+  output_size_list_ = helper_ptr_->GetOutputSizeList();
+  workspace_size_list_ = helper_ptr_->GetWorkSizeList();
+  return KRET_OK;
+}
+
+std::vector<KernelAttr> HypotGpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, HypotPtrCreatorFunc> &item) { return item.first; });
+  return support_list;
+}
+
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Hypot, HypotGpuKernelMod);
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.h
new file mode 100644
index 00000000000..fb0cba86ef1
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/hypot_gpu_kernel.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H
+#include <algorithm>
+#include <functional>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "mindspore/core/ops/hypot.h"
+#include "plugin/device/gpu/kernel/gpu_kernel.h"
+#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h"
+
+namespace mindspore {
+namespace kernel {
+class HypotGpuKernelMod : public NativeGpuKernelMod {
+ public:
+  HypotGpuKernelMod() {}
+  ~HypotGpuKernelMod() = default;
+
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+
+  int Resize(
+    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+    const std::vector<KernelTensorPtr> &outputs,
+    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
+
+  std::vector<KernelAttr> GetOpSupport() override;
+
+ private:
+  std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H
diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py
index 8e94671397c..cd09956f9f6 100644
--- a/mindspore/python/mindspore/ops/operations/math_ops.py
+++ b/mindspore/python/mindspore/ops/operations/math_ops.py
@@ -2635,7 +2635,7 @@ class Hypot(Primitive):
         ValueError: If shape of two inputs are not broadcastable.
 
     Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> x1 = Tensor(np.array([3., 5., 7.]))
diff --git a/tests/st/ops/gpu/test_hypot_op.py b/tests/st/ops/gpu/test_hypot_op.py
new file mode 100644
index 00000000000..08cd19aaae5
--- /dev/null
+++ b/tests/st/ops/gpu/test_hypot_op.py
@@ -0,0 +1,70 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ + +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +import mindspore.context as context +import mindspore.ops.operations.math_ops as P +context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + + +class NetHypot(nn.Cell): + def __init__(self): + super(NetHypot, self).__init__() + self.hypot = P.Hypot() + + def construct(self, x1, x2): + return self.hypot(x1, x2) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_hypot_fp32(): + """ + Feature: Hypot + Description: test cases for Hypot of float32 + Expectation: the results are as expected + """ + x1_np = np.array([3, 4]).astype(np.float32) + x2_np = np.array([4, 3]).astype(np.float32) + input_x1 = Tensor(x1_np) + input_x2 = Tensor(x2_np) + net = NetHypot() + output_ms = net(input_x1, input_x2) + expect_output = np.array([5.0, 5.0]).astype(np.float32) + assert np.allclose(output_ms.asnumpy(), expect_output) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_hypot_fp64(): + """ + Feature: Hypot + Description: test cases for Hypot of float64 + Expectation: the results are as expected + """ + x1_np = np.array([1.2, 3.4, 2.4, 1.3]).astype(np.float64) + x2_np = np.array([2.3, 1.1, 0.9, 0.3]).astype(np.float64) + input_x1 = Tensor(x1_np) + input_x2 = Tensor(x2_np) + net = NetHypot() + output_ms = net(input_x1, input_x2) + expect_output = np.array([2.59422435, 3.57351368, 2.56320112, 1.33416641]).astype(np.float64) + assert np.allclose(output_ms.asnumpy(), expect_output)
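+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_hypot_broadcast_fp32():
+    """
+    Feature: Hypot
+    Description: minimal broadcast sketch for Hypot of float32; shapes (2, 1) and
+        (1, 2) broadcast to (2, 2), exercising the BroadcastHypot kernel path.
+        Assumes numpy-style broadcasting, which the op docstring implies.
+    Expectation: the results match np.hypot
+    """
+    x1_np = np.array([[3.0], [6.0]]).astype(np.float32)
+    x2_np = np.array([[4.0, 8.0]]).astype(np.float32)
+    input_x1 = Tensor(x1_np)
+    input_x2 = Tensor(x2_np)
+    net = NetHypot()
+    output_ms = net(input_x1, input_x2)
+    expect_output = np.hypot(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), expect_output)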