From 95f2743d85f86dfab4342d4f7268905dd6e791c8 Mon Sep 17 00:00:00 2001
From: fandawei
Date: Wed, 8 Feb 2023 22:26:07 +0800
Subject: [PATCH] modify the logic of float_power for complex

---
 .../mindspore/ops/function/math_func.py | 26 ++++++++-----------
 tests/st/ops/test_ops_float_power.py    | 19 --------------
 tests/st/tensor/test_float_power.py     | 19 --------------
 3 files changed, 11 insertions(+), 53 deletions(-)

diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py
index 7a46dc7d579..305aefbdeee 100644
--- a/mindspore/python/mindspore/ops/function/math_func.py
+++ b/mindspore/python/mindspore/ops/function/math_func.py
@@ -954,11 +954,8 @@ def divide(x, other, *, rounding_mode=None):
 def float_power(x, exponent):
     """
     Computes `x` to the power of the exponent.
-    For the real number type, use mindspore.float64 to calculate.
-    For the complex type, use the same type of calculation as the input data.
-
-    .. Note::
-        On GPU, complex dtypes are not supported.
+    For the real number type, cast `x` and `exponent` to mindspore.float64 for calculation.
+    Currently, complex type calculation is not supported.
 
     Args:
         x (Union[Tensor, Number]): The first input is a tensor or a number.
@@ -983,22 +980,21 @@
         >>> print(output)
         [2.25 0. 4. ]
     """
-    if not (isinstance(x, (Tensor, Tensor_)) or isinstance(exponent, (Tensor, Tensor_))):
+    if not (isinstance(x, Tensor) or isinstance(exponent, Tensor)):
         raise TypeError("At least one of the types of inputs must be tensor, " + \
                         f"but the type of 'x' got is {type(x)}, " + \
                         f"and the type of 'exponent' is {type(exponent)}.")
-    if not isinstance(x, (Tensor, Tensor_, numbers.Number)):
+    if not isinstance(x, (Tensor, numbers.Number)):
         raise TypeError(f"The type of 'x' must be Tensor or Number, but got {type(x)}.")
-    if not isinstance(exponent, (Tensor, Tensor_, numbers.Number)):
+    if not isinstance(exponent, (Tensor, numbers.Number)):
        raise TypeError(f"The type of 'exponent' must be Tensor or Number, but got {type(exponent)}.")
 
-    if isinstance(x, (Tensor, Tensor_)) and is_complex(x) and isinstance(exponent, numbers.Number):
-        exponent = cast_(exponent, x.dtype)
-    elif isinstance(exponent, (Tensor, Tensor_)) and is_complex(exponent) and isinstance(x, numbers.Number):
-        x = cast_(x, exponent.dtype)
-    # If both x and exponent are complex Tensor, no processing is required.
-    elif not (isinstance(x, (Tensor, Tensor_)) and is_complex(x) and \
-            isinstance(exponent, (Tensor, Tensor_)) and is_complex(exponent)):
+    if (isinstance(x, Tensor) and is_complex(x)) or \
+            (isinstance(exponent, Tensor) and is_complex(exponent)) or \
+            isinstance(x, complex) or isinstance(exponent, complex):
+        x = cast_(x, mstype.complex128)
+        exponent = cast_(exponent, mstype.complex128)
+    else:
         x = cast_(x, mstype.float64)
         exponent = cast_(exponent, mstype.float64)
     return pow(x, exponent)
diff --git a/tests/st/ops/test_ops_float_power.py b/tests/st/ops/test_ops_float_power.py
index f2dbe6ef7b3..78588d6dbda 100644
--- a/tests/st/ops/test_ops_float_power.py
+++ b/tests/st/ops/test_ops_float_power.py
@@ -46,22 +46,3 @@
     except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32)
     assert output_case.asnumpy().dtype == np.float64
     assert np.allclose(output_case.asnumpy(), except_case)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_arm_cpu
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
-def test_float_power_complex(mode):
-    """
-    Feature: ops.float_power
-    Description: Verify the result of float_power
-    Expectation: success
-    """
-    ms.set_context(mode=mode)
-    net = Net()
-    input_case = ms.Tensor(np.array([complex(2, 3), complex(3, 4)]), ms.complex64)
-    output_case = net(input_case, 2)
-    except_case = np.array([complex(-5, 12), complex(-7, 24)], dtype=np.complex64)
-    assert np.allclose(output_case.asnumpy(), except_case)
diff --git a/tests/st/tensor/test_float_power.py b/tests/st/tensor/test_float_power.py
index 4db89b374ca..fccbb3bdfd6 100644
--- a/tests/st/tensor/test_float_power.py
+++ b/tests/st/tensor/test_float_power.py
@@ -44,22 +44,3 @@
     except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32)
     assert output_case.asnumpy().dtype == np.float64
     assert np.allclose(output_case.asnumpy(), except_case)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_arm_cpu
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
-def test_float_power_complex(mode):
-    """
-    Feature: tensor.float_power
-    Description: Verify the result of float_power
-    Expectation: success
-    """
-    ms.set_context(mode=mode)
-    net = Net()
-    input_case = ms.Tensor(np.array([complex(2, 3), complex(3, 4)]), ms.complex64)
-    output_case = net(input_case, 2)
-    except_case = np.array([complex(-5, 12), complex(-7, 24)], dtype=np.complex64)
-    assert np.allclose(output_case.asnumpy(), except_case)
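
For reference, a minimal NumPy-only sketch of the type-promotion rule this patch introduces (any complex operand promotes both operands to complex128, otherwise both are computed in float64). It is purely illustrative, does not use the MindSpore kernels, and float_power_sketch is a hypothetical helper name:

import numpy as np

def float_power_sketch(x, exponent):
    """Illustration only: mirrors the casting rule added in this patch."""
    x, exponent = np.asarray(x), np.asarray(exponent)
    if np.iscomplexobj(x) or np.iscomplexobj(exponent):
        # Any complex operand promotes the whole computation to complex128.
        x, exponent = x.astype(np.complex128), exponent.astype(np.complex128)
    else:
        # Real-valued inputs are always computed in float64.
        x, exponent = x.astype(np.float64), exponent.astype(np.float64)
    return np.power(x, exponent)

print(float_power_sketch([-1.5, 0.0, 2.0], 2))   # [2.25 0.   4.  ]  (float64)
print(float_power_sketch([2 + 3j, 3 + 4j], 2))   # [-5.+12.j -7.+24.j]  (complex128)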