support vm for pack

This commit is contained in:
jiangjinsheng 2020-05-13 09:37:17 +08:00
parent 18c9495000
commit f9bd460c96
5 changed files with 150 additions and 0 deletions

View File

@ -181,3 +181,5 @@ from .sgd import sgd_op_info
from .lars_update import lars_update_op_info
from .bn_training_update_v2 import _bn_training_update_v2_tbe
from .square_sum_all import square_sum_all_op_info
from .pack import _pack_tbe
from .unpack import _unpack_tbe

View File

@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pack op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
pack_op_info = TBERegOp("Pack") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("pack.so") \
.compute_cost(10) \
.kernel_name("pack") \
.partial_flag(True) \
.attr("axis", "optional", "int", "all") \
.input(0, "x", False, "dynamic", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I8_NDHWC, DataType.I8_NDHWC) \
.dtype_format(DataType.I16_NDHWC, DataType.I16_NDHWC) \
.dtype_format(DataType.I32_NDHWC, DataType.I32_NDHWC) \
.dtype_format(DataType.I64_NDHWC, DataType.I64_NDHWC) \
.dtype_format(DataType.U8_NDHWC, DataType.U8_NDHWC) \
.dtype_format(DataType.U16_NDHWC, DataType.U16_NDHWC) \
.dtype_format(DataType.U32_NDHWC, DataType.U32_NDHWC) \
.dtype_format(DataType.U64_NDHWC, DataType.U64_NDHWC) \
.dtype_format(DataType.F16_NDHWC, DataType.F16_NDHWC) \
.dtype_format(DataType.F32_NDHWC, DataType.F32_NDHWC) \
.dtype_format(DataType.BOOL_NDHWC, DataType.BOOL_NDHWC) \
.get_op_info()
@op_info_register(pack_op_info)
def _pack_tbe():
"""Pack TBE register"""
return

View File

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unpack op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
unpack_op_info = TBERegOp("Unpack") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("unpack.so") \
.compute_cost(10) \
.kernel_name("unpack") \
.partial_flag(True) \
.attr("num", "optional", "int", "all") \
.attr("axis", "required", "int", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "dynamic", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
.dtype_format(DataType.I16_5HD, DataType.I16_5HD) \
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
.dtype_format(DataType.I64_5HD, DataType.I64_5HD) \
.dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
.dtype_format(DataType.U16_5HD, DataType.U16_5HD) \
.dtype_format(DataType.U32_5HD, DataType.U32_5HD) \
.dtype_format(DataType.U64_5HD, DataType.U64_5HD) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(unpack_op_info)
def _unpack_tbe():
"""Unpack TBE register"""
return

View File

@ -499,6 +499,7 @@ class DataType:
BOOL_NCHW = ("bool", "NCHW")
BOOL_NHWC = ("bool", "NHWC")
BOOL_HWCN = ("bool", "HWCN")
BOOL_NDHWC = ("bool", "NDHWC")
I8_None = ("int8", "")
I8_Default = ("int8", "DefaultFormat")
@ -509,6 +510,7 @@ class DataType:
I8_NCHW = ("int8", "NCHW")
I8_NHWC = ("int8", "NHWC")
I8_HWCN = ("int8", "HWCN")
I8_NDHWC = ("int8", "NDHWC")
U8_None = ("uint8", "")
U8_Default = ("uint8", "DefaultFormat")
@ -519,6 +521,7 @@ class DataType:
U8_NCHW = ("uint8", "NCHW")
U8_NHWC = ("uint8", "NHWC")
U8_HWCN = ("uint8", "HWCN")
U8_NDHWC = ("uint8", "NDHWC")
I16_None = ("int16", "")
I16_Default = ("int16", "DefaultFormat")
@ -529,6 +532,7 @@ class DataType:
I16_NCHW = ("int16", "NCHW")
I16_NHWC = ("int16", "NHWC")
I16_HWCN = ("int16", "HWCN")
I16_NDHWC = ("int16", "NDHWC")
U16_None = ("uint16", "")
U16_Default = ("uint16", "DefaultFormat")
@ -539,6 +543,7 @@ class DataType:
U16_NCHW = ("uint16", "NCHW")
U16_NHWC = ("uint16", "NHWC")
U16_HWCN = ("uint16", "HWCN")
U16_NDHWC = ("uint16", "NDHWC")
I32_None = ("int32", "")
I32_Default = ("int32", "DefaultFormat")
@ -549,6 +554,7 @@ class DataType:
I32_NCHW = ("int32", "NCHW")
I32_NHWC = ("int32", "NHWC")
I32_HWCN = ("int32", "HWCN")
I32_NDHWC = ("int32", "NDHWC")
U32_None = ("uint32", "")
U32_Default = ("uint32", "DefaultFormat")
@ -559,6 +565,7 @@ class DataType:
U32_NCHW = ("uint32", "NCHW")
U32_NHWC = ("uint32", "NHWC")
U32_HWCN = ("uint32", "HWCN")
U32_NDHWC = ("uint32", "NDHWC")
I64_None = ("int64", "")
I64_Default = ("int64", "DefaultFormat")
@ -569,6 +576,7 @@ class DataType:
I64_NCHW = ("int64", "NCHW")
I64_NHWC = ("int64", "NHWC")
I64_HWCN = ("int64", "HWCN")
I64_NDHWC = ("int64", "NDHWC")
U64_None = ("uint64", "")
U64_Default = ("uint64", "DefaultFormat")
@ -579,6 +587,7 @@ class DataType:
U64_NCHW = ("uint64", "NCHW")
U64_NHWC = ("uint64", "NHWC")
U64_HWCN = ("uint64", "HWCN")
U64_NDHWC = ("uint64", "NDHWC")
F16_None = ("float16", "")
F16_Default = ("float16", "DefaultFormat")
@ -589,6 +598,7 @@ class DataType:
F16_NCHW = ("float16", "NCHW")
F16_NHWC = ("float16", "NHWC")
F16_HWCN = ("float16", "HWCN")
F16_NDHWC = ("float16", "NDHWC")
F32_None = ("float32", "")
F32_Default = ("float32", "DefaultFormat")
@ -599,6 +609,7 @@ class DataType:
F32_NCHW = ("float32", "NCHW")
F32_NHWC = ("float32", "NHWC")
F32_HWCN = ("float32", "HWCN")
F32_NDHWC = ("float32", "NDHWC")
F64_None = ("float64", "")
F64_Default = ("float64", "DefaultFormat")
@ -609,3 +620,4 @@ class DataType:
F64_NCHW = ("float64", "NCHW")
F64_NHWC = ("float64", "NHWC")
F64_HWCN = ("float64", "HWCN")
F64_NDHWC = ("float64", "NDHWC")

View File

@ -227,6 +227,23 @@ class SpaceToBatchNet(Cell):
return self.space_to_batch(x)
class PackNet(Cell):
def __init__(self):
super(PackNet, self).__init__()
self.pack = P.Pack()
def construct(self, x):
return self.pack((x, x))
class UnpackNet(Cell):
def __init__(self):
super(UnpackNet, self).__init__()
self.unpack = P.Unpack()
def construct(self, x):
return self.unpack(x)
test_case_array_ops = [
('CustNet1', {
'block': CustNet1(),
@ -249,6 +266,12 @@ test_case_array_ops = [
('SpaceToBatchNet', {
'block': SpaceToBatchNet(),
'desc_inputs': [Tensor(np.array([[[[1, 2], [3, 4]]]]).astype(np.float16))]}),
('PackNet', {
'block': PackNet(),
'desc_inputs': [Tensor(np.array([[[1, 2], [3, 4]]]).astype(np.float16))]}),
('UnpackNet', {
'block': UnpackNet(),
'desc_inputs': [Tensor(np.array([[1, 2], [3, 4]]).astype(np.float16))]}),
]
test_case_lists = [test_case_array_ops]