forked from mindspore-Ecosystem/mindspore
modify the logic of float_power for complex
This commit is contained in:
parent
c7dce3e513
commit
95f2743d85
|
@ -954,11 +954,8 @@ def divide(x, other, *, rounding_mode=None):
|
|||
def float_power(x, exponent):
|
||||
"""
|
||||
Computes `x` to the power of the exponent.
|
||||
For the real number type, use mindspore.float64 to calculate.
|
||||
For the complex type, use the same type of calculation as the input data.
|
||||
|
||||
.. Note::
|
||||
On GPU, complex dtypes are not supported.
|
||||
For the real number type, cast x and expoent to mindspore.float64 to calculate.
|
||||
Currently, complex type calculation is not supported.
|
||||
|
||||
Args:
|
||||
x (Union[Tensor, Number]): The first input is a tensor or a number.
|
||||
|
@ -983,22 +980,21 @@ def float_power(x, exponent):
|
|||
>>> print(output)
|
||||
[2.25 0. 4. ]
|
||||
"""
|
||||
if not (isinstance(x, (Tensor, Tensor_)) or isinstance(exponent, (Tensor, Tensor_))):
|
||||
if not (isinstance(x, Tensor) or isinstance(exponent, Tensor)):
|
||||
raise TypeError("At least one of the types of inputs must be tensor, " + \
|
||||
f"but the type of 'x' got is {type(x)}, " + \
|
||||
f"and the type of 'exponent' is {type(exponent)}.")
|
||||
if not isinstance(x, (Tensor, Tensor_, numbers.Number)):
|
||||
if not isinstance(x, (Tensor, numbers.Number)):
|
||||
raise TypeError(f"The type of 'x' must be Tensor or Number, but got {type(x)}.")
|
||||
if not isinstance(exponent, (Tensor, Tensor_, numbers.Number)):
|
||||
if not isinstance(exponent, (Tensor, numbers.Number)):
|
||||
raise TypeError(f"The type of 'exponent' must be Tensor or Number, but got {type(exponent)}.")
|
||||
|
||||
if isinstance(x, (Tensor, Tensor_)) and is_complex(x) and isinstance(exponent, numbers.Number):
|
||||
exponent = cast_(exponent, x.dtype)
|
||||
elif isinstance(exponent, (Tensor, Tensor_)) and is_complex(exponent) and isinstance(x, numbers.Number):
|
||||
x = cast_(x, exponent.dtype)
|
||||
# If both x and exponent are complex Tensor, no processing is required.
|
||||
elif not (isinstance(x, (Tensor, Tensor_)) and is_complex(x) and \
|
||||
isinstance(exponent, (Tensor, Tensor_)) and is_complex(exponent)):
|
||||
if (isinstance(x, Tensor) and is_complex(x)) or \
|
||||
(isinstance(exponent, Tensor) and is_complex(exponent)) or \
|
||||
isinstance(x, complex) or isinstance(exponent, complex):
|
||||
x = cast_(x, mstype.complex128)
|
||||
exponent = cast_(exponent, mstype.complex128)
|
||||
else:
|
||||
x = cast_(x, mstype.float64)
|
||||
exponent = cast_(exponent, mstype.float64)
|
||||
return pow(x, exponent)
|
||||
|
|
|
@ -46,22 +46,3 @@ def test_float_power_real(mode):
|
|||
except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32)
|
||||
assert output_case.asnumpy().dtype == np.float64
|
||||
assert np.allclose(output_case.asnumpy(), except_case)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_float_power_complex(mode):
|
||||
"""
|
||||
Feature: ops.float_power
|
||||
Description: Verify the result of float_power
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
input_case = ms.Tensor(np.array([complex(2, 3), complex(3, 4)]), ms.complex64)
|
||||
output_case = net(input_case, 2)
|
||||
except_case = np.array([complex(-5, 12), complex(-7, 24)], dtype=np.complex64)
|
||||
assert np.allclose(output_case.asnumpy(), except_case)
|
||||
|
|
|
@ -44,22 +44,3 @@ def test_float_power_real(mode):
|
|||
except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32)
|
||||
assert output_case.asnumpy().dtype == np.float64
|
||||
assert np.allclose(output_case.asnumpy(), except_case)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_float_power_complex(mode):
|
||||
"""
|
||||
Feature: tensor.float_power
|
||||
Description: Verify the result of float_power
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
input_case = ms.Tensor(np.array([complex(2, 3), complex(3, 4)]), ms.complex64)
|
||||
output_case = net(input_case, 2)
|
||||
except_case = np.array([complex(-5, 12), complex(-7, 24)], dtype=np.complex64)
|
||||
assert np.allclose(output_case.asnumpy(), except_case)
|
||||
|
|
Loading…
Reference in New Issue