forked from mindspore-Ecosystem/mindspore
!25457 add broadcast GPU float64 registration
Merge pull request !25457 from zhujingxuan/master
Commit ba0e1a810e
@@ -27,6 +27,10 @@ MS_REG_GPU_KERNEL_ONE(
  Minimum,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
  BroadcastOpGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(
  Maximum,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
  BroadcastOpGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(
  Less, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
  BroadcastOpGpuKernel, double)
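The three registrations above instantiate BroadcastOpGpuKernel for double, so Minimum, Maximum, and Less can now dispatch on float64 inputs on GPU (Less keeps a bool output, per its AddOutputAttr). A minimal sketch of what this enables, assuming a MindSpore build with GPU support and the same import conventions as the test file below:

    import numpy as np
    from mindspore import Tensor, context
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    # float64 tensors with broadcastable shapes (2,) and (2, 1)
    x1 = Tensor(np.random.rand(2).astype(np.float64))
    x2 = Tensor(np.random.rand(2, 1).astype(np.float64))

    print(P.Minimum()(x1, x2))  # float64 in, float64 out
    print(P.Less()(x1, x2))     # float64 in, bool out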
@@ -297,6 +297,66 @@ def test_broadcast_diff_dims():
    assert np.allclose(output_ms.asnumpy(), output_np)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_broadcast_diff_dims_float64():
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for broadcast operations execpted for DivNoNan
|
||||
Expectation: the result match numpy results
|
||||
"""
|
||||
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    np.random.seed(42)
    x1_np = np.random.rand(2).astype(np.float64)
    x2_np = np.random.rand(2, 1).astype(np.float64)

    output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.minimum(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.maximum(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np))
    output_np = x1_np > x2_np
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np))
    output_np = x1_np < x2_np
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.power(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.RealDiv()(Tensor(x1_np), Tensor(x2_np))
    output_np = x1_np / x2_np
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Mul()(Tensor(x1_np), Tensor(x2_np))
    output_np = x1_np * x2_np
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
    output_np = x1_np - x2_np
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Mod()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.fmod(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.FloorMod()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.mod(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)

    output_ms = P.Atan2()(Tensor(x1_np), Tensor(x2_np))
    output_np = np.arctan2(x1_np, x2_np)
    assert np.allclose(output_ms.asnumpy(), output_np)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
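A note on the modulo checks in the test above: P.Mod is compared against np.fmod while P.FloorMod is compared against np.mod, because NumPy's two remainder functions follow different sign conventions. A quick illustration:

    import numpy as np

    # fmod: truncated remainder, takes the sign of the dividend
    print(np.fmod(-7.0, 3.0))  # -1.0
    # mod: floored remainder, takes the sign of the divisor
    print(np.mod(-7.0, 3.0))   # 2.0

With the all-positive random inputs used in the test, the two functions agree, so each MindSpore op is simply checked against the NumPy function that shares its convention.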