forked from mindspore-Ecosystem/mindspore
modify the logic of float_power for complex
parent: c7dce3e513
commit: 95f2743d85

@@ -954,11 +954,8 @@ def divide(x, other, *, rounding_mode=None):
 def float_power(x, exponent):
     """
     Computes `x` to the power of the exponent.
-    For the real number type, use mindspore.float64 to calculate.
-    For the complex type, use the same type of calculation as the input data.
+    For the real number type, cast `x` and `exponent` to mindspore.float64 to calculate.
+    Currently, complex type calculation is not supported.
 
-    .. Note::
-        On GPU, complex dtypes are not supported.
-
     Args:
         x (Union[Tensor, Number]): The first input is a tensor or a number.
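
A minimal usage sketch of the float64 promotion the revised docstring describes; the input values are an assumption chosen to reproduce the `[2.25 0. 4. ]` example that appears in the next hunk, and exact print formatting may differ by build:

```python
import numpy as np
import mindspore as ms
from mindspore import ops

# Real inputs are cast to mindspore.float64 before the power is computed,
# so the result dtype is float64 even though the input is float32.
x = ms.Tensor(np.array([1.5, 0.0, 2.0]), ms.float32)
output = ops.float_power(x, 2)
print(output.dtype)   # Float64
print(output)         # [2.25 0.   4.  ]
```
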
@@ -983,22 +980,21 @@ def float_power(x, exponent):
         >>> print(output)
         [2.25 0. 4. ]
     """
-    if not (isinstance(x, (Tensor, Tensor_)) or isinstance(exponent, (Tensor, Tensor_))):
+    if not (isinstance(x, Tensor) or isinstance(exponent, Tensor)):
         raise TypeError("At least one of the types of inputs must be tensor, " + \
                         f"but the type of 'x' got is {type(x)}, " + \
                         f"and the type of 'exponent' is {type(exponent)}.")
-    if not isinstance(x, (Tensor, Tensor_, numbers.Number)):
+    if not isinstance(x, (Tensor, numbers.Number)):
         raise TypeError(f"The type of 'x' must be Tensor or Number, but got {type(x)}.")
-    if not isinstance(exponent, (Tensor, Tensor_, numbers.Number)):
+    if not isinstance(exponent, (Tensor, numbers.Number)):
         raise TypeError(f"The type of 'exponent' must be Tensor or Number, but got {type(exponent)}.")
 
-    if isinstance(x, (Tensor, Tensor_)) and is_complex(x) and isinstance(exponent, numbers.Number):
-        exponent = cast_(exponent, x.dtype)
-    elif isinstance(exponent, (Tensor, Tensor_)) and is_complex(exponent) and isinstance(x, numbers.Number):
-        x = cast_(x, exponent.dtype)
-    # If both x and exponent are complex Tensor, no processing is required.
-    elif not (isinstance(x, (Tensor, Tensor_)) and is_complex(x) and \
-              isinstance(exponent, (Tensor, Tensor_)) and is_complex(exponent)):
+    if (isinstance(x, Tensor) and is_complex(x)) or \
+            (isinstance(exponent, Tensor) and is_complex(exponent)) or \
+            isinstance(x, complex) or isinstance(exponent, complex):
+        x = cast_(x, mstype.complex128)
+        exponent = cast_(exponent, mstype.complex128)
+    else:
         x = cast_(x, mstype.float64)
         exponent = cast_(exponent, mstype.float64)
     return pow(x, exponent)
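
For reference, a standalone NumPy sketch that mirrors the new promotion rule (any complex operand pushes the computation to complex128, everything else goes to float64). `Tensor`, `Tensor_`, `is_complex` and `cast_` in the diff are MindSpore internals, so this analogue is illustration only, not the library implementation:

```python
import numpy as np

def float_power_sketch(x, exponent):
    """Illustrative analogue of the new float_power dtype-promotion rule."""
    def is_cplx(v):
        # Python complex scalars or complex-typed ndarrays count as complex.
        return isinstance(v, complex) or (
            isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.complexfloating))

    # Either operand complex -> compute in complex128; otherwise in float64.
    target = np.complex128 if (is_cplx(x) or is_cplx(exponent)) else np.float64
    return np.power(np.asarray(x, dtype=target), np.asarray(exponent, dtype=target))

print(float_power_sketch(np.array([1.5, 0.0, 2.0]), 2))   # [2.25 0.   4.  ]
print(float_power_sketch(np.array([2 + 3j, 3 + 4j]), 2))  # [-5.+12.j -7.+24.j]
```
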
@@ -46,22 +46,3 @@ def test_float_power_real(mode):
     except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32)
     assert output_case.asnumpy().dtype == np.float64
     assert np.allclose(output_case.asnumpy(), except_case)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_arm_cpu
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
-def test_float_power_complex(mode):
-    """
-    Feature: ops.float_power
-    Description: Verify the result of float_power
-    Expectation: success
-    """
-    ms.set_context(mode=mode)
-    net = Net()
-    input_case = ms.Tensor(np.array([complex(2, 3), complex(3, 4)]), ms.complex64)
-    output_case = net(input_case, 2)
-    except_case = np.array([complex(-5, 12), complex(-7, 24)], dtype=np.complex64)
-    assert np.allclose(output_case.asnumpy(), except_case)
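
The expected values in the removed complex test follow directly from squaring: (2+3i)^2 = 4 + 12i + 9i^2 = -5 + 12i, and (3+4i)^2 = 9 + 24i + 16i^2 = -7 + 24i. A quick check with plain Python complex arithmetic:

```python
# Verifies the removed test's expected values with built-in complex math.
print(complex(2, 3) ** 2)   # (-5+12j)
print(complex(3, 4) ** 2)   # (-7+24j)
```
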
@@ -44,22 +44,3 @@ def test_float_power_real(mode):
     except_case = np.array([9.0000, 4.0000, 1.0000, 1.0000, 4.0000, 9.0000], dtype=np.float32)
     assert output_case.asnumpy().dtype == np.float64
     assert np.allclose(output_case.asnumpy(), except_case)
-
-
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_arm_cpu
-@pytest.mark.env_onecard
-@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
-def test_float_power_complex(mode):
-    """
-    Feature: tensor.float_power
-    Description: Verify the result of float_power
-    Expectation: success
-    """
-    ms.set_context(mode=mode)
-    net = Net()
-    input_case = ms.Tensor(np.array([complex(2, 3), complex(3, 4)]), ms.complex64)
-    output_case = net(input_case, 2)
-    except_case = np.array([complex(-5, 12), complex(-7, 24)], dtype=np.complex64)
-    assert np.allclose(output_case.asnumpy(), except_case)
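
This second hunk removes the same test from the Tensor-interface suite (`Feature: tensor.float_power`). A minimal sketch of that method-style call, assuming `Tensor.float_power` is exposed on the installed MindSpore build:

```python
import numpy as np
import mindspore as ms

# Method form of the same operation exercised by this test file.
x = ms.Tensor(np.array([1.5, 0.0, 2.0]), ms.float32)
print(x.float_power(2))   # [2.25 0.   4.  ], result dtype float64
```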