modify the logic of float_power for complex

This commit is contained in:
fandawei 2023-02-08 22:26:07 +08:00
parent c7dce3e513
commit 95f2743d85
3 changed files with 11 additions and 53 deletions

View File

@ -954,11 +954,8 @@ def divide(x, other, *, rounding_mode=None):
def float_power(x, exponent): def float_power(x, exponent):
""" """
Computes `x` to the power of the exponent. Computes `x` to the power of the exponent.
For the real number type, use mindspore.float64 to calculate. For the real number type, cast x and expoent to mindspore.float64 to calculate.
For the complex type, use the same type of calculation as the input data. Currently, complex type calculation is not supported.
.. Note::
On GPU, complex dtypes are not supported.
Args: Args:
x (Union[Tensor, Number]): The first input is a tensor or a number. x (Union[Tensor, Number]): The first input is a tensor or a number.
@ -983,22 +980,21 @@ def float_power(x, exponent):
>>> print(output) >>> print(output)
[2.25 0. 4. ] [2.25 0. 4. ]
""" """
if not (isinstance(x, (Tensor, Tensor_)) or isinstance(exponent, (Tensor, Tensor_))): if not (isinstance(x, Tensor) or isinstance(exponent, Tensor)):
raise TypeError("At least one of the types of inputs must be tensor, " + \ raise TypeError("At least one of the types of inputs must be tensor, " + \
f"but the type of 'x' got is {type(x)}, " + \ f"but the type of 'x' got is {type(x)}, " + \
f"and the type of 'exponent' is {type(exponent)}.") f"and the type of 'exponent' is {type(exponent)}.")
if not isinstance(x, (Tensor, Tensor_, numbers.Number)): if not isinstance(x, (Tensor, numbers.Number)):
raise TypeError(f"The type of 'x' must be Tensor or Number, but got {type(x)}.") raise TypeError(f"The type of 'x' must be Tensor or Number, but got {type(x)}.")
if not isinstance(exponent, (Tensor, Tensor_, numbers.Number)): if not isinstance(exponent, (Tensor, numbers.Number)):
raise TypeError(f"The type of 'exponent' must be Tensor or Number, but got {type(exponent)}.") raise TypeError(f"The type of 'exponent' must be Tensor or Number, but got {type(exponent)}.")
if isinstance(x, (Tensor, Tensor_)) and is_complex(x) and isinstance(exponent, numbers.Number): if (isinstance(x, Tensor) and is_complex(x)) or \
exponent = cast_(exponent, x.dtype) (isinstance(exponent, Tensor) and is_complex(exponent)) or \
elif isinstance(exponent, (Tensor, Tensor_)) and is_complex(exponent) and isinstance(x, numbers.Number): isinstance(x, complex) or isinstance(exponent, complex):
x = cast_(x, exponent.dtype) x = cast_(x, mstype.complex128)
# If both x and exponent are complex Tensor, no processing is required. exponent = cast_(exponent, mstype.complex128)
elif not (isinstance(x, (Tensor, Tensor_)) and is_complex(x) and \ else:
isinstance(exponent, (Tensor, Tensor_)) and is_complex(exponent)):
x = cast_(x, mstype.float64) x = cast_(x, mstype.float64)
exponent = cast_(exponent, mstype.float64) exponent = cast_(exponent, mstype.float64)
return pow(x, exponent) return pow(x, exponent)

View File

@ -46,22 +46,3 @@ def test_float_power_real(mode):
except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32) except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32)
assert output_case.asnumpy().dtype == np.float64 assert output_case.asnumpy().dtype == np.float64
assert np.allclose(output_case.asnumpy(), except_case) assert np.allclose(output_case.asnumpy(), except_case)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_float_power_complex(mode):
"""
Feature: ops.float_power
Description: Verify the result of float_power
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
input_case = ms.Tensor(np.array([complex(2, 3), complex(3, 4)]), ms.complex64)
output_case = net(input_case, 2)
except_case = np.array([complex(-5, 12), complex(-7, 24)], dtype=np.complex64)
assert np.allclose(output_case.asnumpy(), except_case)

View File

@ -44,22 +44,3 @@ def test_float_power_real(mode):
except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32) except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32)
assert output_case.asnumpy().dtype == np.float64 assert output_case.asnumpy().dtype == np.float64
assert np.allclose(output_case.asnumpy(), except_case) assert np.allclose(output_case.asnumpy(), except_case)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_float_power_complex(mode):
"""
Feature: tensor.float_power
Description: Verify the result of float_power
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
input_case = ms.Tensor(np.array([complex(2, 3), complex(3, 4)]), ms.complex64)
output_case = net(input_case, 2)
except_case = np.array([complex(-5, 12), complex(-7, 24)], dtype=np.complex64)
assert np.allclose(output_case.asnumpy(), except_case)