!12399 Add type support to Squeeze gpu op

From: @peilin-wang
Reviewed-by: @robingrosman,@tom__chen
Signed-off-by: @robingrosman
This commit is contained in:
mindspore-ci-bot 2021-02-19 21:33:20 +08:00 committed by Gitee
commit feb07198e7
2 changed files with 73 additions and 42 deletions

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd # Copyright 2019-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -22,9 +22,14 @@ squeeze_op_info = AkgGpuRegOp("Squeeze") \
.attr("axis", "optional", "listInt") \ .attr("axis", "optional", "listInt") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \ .dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \ .dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \ .dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.get_op_info() .get_op_info()

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd # Copyright 2019-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,67 +13,93 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
import numpy as np import numpy as np
import pytest
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") class SqueezeNet(nn.Cell):
class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(SqueezeNet, self).__init__()
self.squeeze = P.Squeeze() self.squeeze = P.Squeeze()
def construct(self, tensor): def construct(self, tensor):
return self.squeeze(tensor) return self.squeeze(tensor)
def test_net_bool(): def squeeze(nptype):
x = np.random.randn(1, 16, 1, 1).astype(np.bool) context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = Net()
np.random.seed(0)
x = np.random.randn(1, 16, 1, 1).astype(nptype)
net = SqueezeNet()
output = net(Tensor(x)) output = net(Tensor(x))
print(output.asnumpy())
assert np.all(output.asnumpy() == x.squeeze()) assert np.all(output.asnumpy() == x.squeeze())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_bool():
squeeze(np.bool)
def test_net_uint8(): @pytest.mark.level0
x = np.random.randn(1, 16, 1, 1).astype(np.uint8) @pytest.mark.platform_x86_gpu_training
net = Net() @pytest.mark.env_onecard
output = net(Tensor(x)) def test_squeeze_uint8():
print(output.asnumpy()) squeeze(np.uint8)
assert np.all(output.asnumpy() == x.squeeze())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_uint16():
squeeze(np.uint16)
def test_net_int16(): @pytest.mark.level0
x = np.random.randn(1, 16, 1, 1).astype(np.int16) @pytest.mark.platform_x86_gpu_training
net = Net() @pytest.mark.env_onecard
output = net(Tensor(x)) def test_squeeze_uint32():
print(output.asnumpy()) squeeze(np.uint32)
assert np.all(output.asnumpy() == x.squeeze())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_int8():
squeeze(np.int8)
def test_net_int32(): @pytest.mark.level0
x = np.random.randn(1, 16, 1, 1).astype(np.int32) @pytest.mark.platform_x86_gpu_training
net = Net() @pytest.mark.env_onecard
output = net(Tensor(x)) def test_squeeze_int16():
print(output.asnumpy()) squeeze(np.int16)
assert np.all(output.asnumpy() == x.squeeze())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_int32():
squeeze(np.int32)
def test_net_float16(): @pytest.mark.level0
x = np.random.randn(1, 16, 1, 1).astype(np.float16) @pytest.mark.platform_x86_gpu_training
net = Net() @pytest.mark.env_onecard
output = net(Tensor(x)) def test_squeeze_int64():
print(output.asnumpy()) squeeze(np.int64)
assert np.all(output.asnumpy() == x.squeeze())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_float16():
squeeze(np.float16)
def test_net_float32(): @pytest.mark.level0
x = np.random.randn(1, 16, 1, 1).astype(np.float32) @pytest.mark.platform_x86_gpu_training
net = Net() @pytest.mark.env_onecard
output = net(Tensor(x)) def test_squeeze_float32():
print(output.asnumpy()) squeeze(np.float32)
assert np.all(output.asnumpy() == x.squeeze())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_squeeze_float64():
squeeze(np.float64)