forked from mindspore-Ecosystem/mindspore
!1111 support vm for pack and unpack
Merge pull request !1111 from jiangjinsheng/vm_pack
This commit is contained in:
commit
fe8b59f26b
|
@ -182,3 +182,5 @@ from .sgd import sgd_op_info
|
||||||
from .lars_update import lars_update_op_info
|
from .lars_update import lars_update_op_info
|
||||||
from .bn_training_update_v2 import _bn_training_update_v2_tbe
|
from .bn_training_update_v2 import _bn_training_update_v2_tbe
|
||||||
from .square_sum_all import square_sum_all_op_info
|
from .square_sum_all import square_sum_all_op_info
|
||||||
|
from .pack import _pack_tbe
|
||||||
|
from .unpack import _unpack_tbe
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Pack op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
pack_op_info = TBERegOp("Pack") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("pack.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("pack") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.attr("axis", "optional", "int", "all") \
|
||||||
|
.input(0, "x", False, "dynamic", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I8_NDHWC, DataType.I8_NDHWC) \
|
||||||
|
.dtype_format(DataType.I16_NDHWC, DataType.I16_NDHWC) \
|
||||||
|
.dtype_format(DataType.I32_NDHWC, DataType.I32_NDHWC) \
|
||||||
|
.dtype_format(DataType.I64_NDHWC, DataType.I64_NDHWC) \
|
||||||
|
.dtype_format(DataType.U8_NDHWC, DataType.U8_NDHWC) \
|
||||||
|
.dtype_format(DataType.U16_NDHWC, DataType.U16_NDHWC) \
|
||||||
|
.dtype_format(DataType.U32_NDHWC, DataType.U32_NDHWC) \
|
||||||
|
.dtype_format(DataType.U64_NDHWC, DataType.U64_NDHWC) \
|
||||||
|
.dtype_format(DataType.F16_NDHWC, DataType.F16_NDHWC) \
|
||||||
|
.dtype_format(DataType.F32_NDHWC, DataType.F32_NDHWC) \
|
||||||
|
.dtype_format(DataType.BOOL_NDHWC, DataType.BOOL_NDHWC) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(pack_op_info)
|
||||||
|
def _pack_tbe():
|
||||||
|
"""Pack TBE register"""
|
||||||
|
return
|
|
@ -0,0 +1,56 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Unpack op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
unpack_op_info = TBERegOp("Unpack") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("unpack.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("unpack") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.attr("num", "optional", "int", "all") \
|
||||||
|
.attr("axis", "required", "int", "all") \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "dynamic", "all") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
|
||||||
|
.dtype_format(DataType.I16_5HD, DataType.I16_5HD) \
|
||||||
|
.dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
|
||||||
|
.dtype_format(DataType.I64_5HD, DataType.I64_5HD) \
|
||||||
|
.dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
|
||||||
|
.dtype_format(DataType.U16_5HD, DataType.U16_5HD) \
|
||||||
|
.dtype_format(DataType.U32_5HD, DataType.U32_5HD) \
|
||||||
|
.dtype_format(DataType.U64_5HD, DataType.U64_5HD) \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(unpack_op_info)
|
||||||
|
def _unpack_tbe():
|
||||||
|
"""Unpack TBE register"""
|
||||||
|
return
|
|
@ -499,6 +499,7 @@ class DataType:
|
||||||
BOOL_NCHW = ("bool", "NCHW")
|
BOOL_NCHW = ("bool", "NCHW")
|
||||||
BOOL_NHWC = ("bool", "NHWC")
|
BOOL_NHWC = ("bool", "NHWC")
|
||||||
BOOL_HWCN = ("bool", "HWCN")
|
BOOL_HWCN = ("bool", "HWCN")
|
||||||
|
BOOL_NDHWC = ("bool", "NDHWC")
|
||||||
|
|
||||||
I8_None = ("int8", "")
|
I8_None = ("int8", "")
|
||||||
I8_Default = ("int8", "DefaultFormat")
|
I8_Default = ("int8", "DefaultFormat")
|
||||||
|
@ -509,6 +510,7 @@ class DataType:
|
||||||
I8_NCHW = ("int8", "NCHW")
|
I8_NCHW = ("int8", "NCHW")
|
||||||
I8_NHWC = ("int8", "NHWC")
|
I8_NHWC = ("int8", "NHWC")
|
||||||
I8_HWCN = ("int8", "HWCN")
|
I8_HWCN = ("int8", "HWCN")
|
||||||
|
I8_NDHWC = ("int8", "NDHWC")
|
||||||
|
|
||||||
U8_None = ("uint8", "")
|
U8_None = ("uint8", "")
|
||||||
U8_Default = ("uint8", "DefaultFormat")
|
U8_Default = ("uint8", "DefaultFormat")
|
||||||
|
@ -519,6 +521,7 @@ class DataType:
|
||||||
U8_NCHW = ("uint8", "NCHW")
|
U8_NCHW = ("uint8", "NCHW")
|
||||||
U8_NHWC = ("uint8", "NHWC")
|
U8_NHWC = ("uint8", "NHWC")
|
||||||
U8_HWCN = ("uint8", "HWCN")
|
U8_HWCN = ("uint8", "HWCN")
|
||||||
|
U8_NDHWC = ("uint8", "NDHWC")
|
||||||
|
|
||||||
I16_None = ("int16", "")
|
I16_None = ("int16", "")
|
||||||
I16_Default = ("int16", "DefaultFormat")
|
I16_Default = ("int16", "DefaultFormat")
|
||||||
|
@ -529,6 +532,7 @@ class DataType:
|
||||||
I16_NCHW = ("int16", "NCHW")
|
I16_NCHW = ("int16", "NCHW")
|
||||||
I16_NHWC = ("int16", "NHWC")
|
I16_NHWC = ("int16", "NHWC")
|
||||||
I16_HWCN = ("int16", "HWCN")
|
I16_HWCN = ("int16", "HWCN")
|
||||||
|
I16_NDHWC = ("int16", "NDHWC")
|
||||||
|
|
||||||
U16_None = ("uint16", "")
|
U16_None = ("uint16", "")
|
||||||
U16_Default = ("uint16", "DefaultFormat")
|
U16_Default = ("uint16", "DefaultFormat")
|
||||||
|
@ -539,6 +543,7 @@ class DataType:
|
||||||
U16_NCHW = ("uint16", "NCHW")
|
U16_NCHW = ("uint16", "NCHW")
|
||||||
U16_NHWC = ("uint16", "NHWC")
|
U16_NHWC = ("uint16", "NHWC")
|
||||||
U16_HWCN = ("uint16", "HWCN")
|
U16_HWCN = ("uint16", "HWCN")
|
||||||
|
U16_NDHWC = ("uint16", "NDHWC")
|
||||||
|
|
||||||
I32_None = ("int32", "")
|
I32_None = ("int32", "")
|
||||||
I32_Default = ("int32", "DefaultFormat")
|
I32_Default = ("int32", "DefaultFormat")
|
||||||
|
@ -549,6 +554,7 @@ class DataType:
|
||||||
I32_NCHW = ("int32", "NCHW")
|
I32_NCHW = ("int32", "NCHW")
|
||||||
I32_NHWC = ("int32", "NHWC")
|
I32_NHWC = ("int32", "NHWC")
|
||||||
I32_HWCN = ("int32", "HWCN")
|
I32_HWCN = ("int32", "HWCN")
|
||||||
|
I32_NDHWC = ("int32", "NDHWC")
|
||||||
|
|
||||||
U32_None = ("uint32", "")
|
U32_None = ("uint32", "")
|
||||||
U32_Default = ("uint32", "DefaultFormat")
|
U32_Default = ("uint32", "DefaultFormat")
|
||||||
|
@ -559,6 +565,7 @@ class DataType:
|
||||||
U32_NCHW = ("uint32", "NCHW")
|
U32_NCHW = ("uint32", "NCHW")
|
||||||
U32_NHWC = ("uint32", "NHWC")
|
U32_NHWC = ("uint32", "NHWC")
|
||||||
U32_HWCN = ("uint32", "HWCN")
|
U32_HWCN = ("uint32", "HWCN")
|
||||||
|
U32_NDHWC = ("uint32", "NDHWC")
|
||||||
|
|
||||||
I64_None = ("int64", "")
|
I64_None = ("int64", "")
|
||||||
I64_Default = ("int64", "DefaultFormat")
|
I64_Default = ("int64", "DefaultFormat")
|
||||||
|
@ -569,6 +576,7 @@ class DataType:
|
||||||
I64_NCHW = ("int64", "NCHW")
|
I64_NCHW = ("int64", "NCHW")
|
||||||
I64_NHWC = ("int64", "NHWC")
|
I64_NHWC = ("int64", "NHWC")
|
||||||
I64_HWCN = ("int64", "HWCN")
|
I64_HWCN = ("int64", "HWCN")
|
||||||
|
I64_NDHWC = ("int64", "NDHWC")
|
||||||
|
|
||||||
U64_None = ("uint64", "")
|
U64_None = ("uint64", "")
|
||||||
U64_Default = ("uint64", "DefaultFormat")
|
U64_Default = ("uint64", "DefaultFormat")
|
||||||
|
@ -579,6 +587,7 @@ class DataType:
|
||||||
U64_NCHW = ("uint64", "NCHW")
|
U64_NCHW = ("uint64", "NCHW")
|
||||||
U64_NHWC = ("uint64", "NHWC")
|
U64_NHWC = ("uint64", "NHWC")
|
||||||
U64_HWCN = ("uint64", "HWCN")
|
U64_HWCN = ("uint64", "HWCN")
|
||||||
|
U64_NDHWC = ("uint64", "NDHWC")
|
||||||
|
|
||||||
F16_None = ("float16", "")
|
F16_None = ("float16", "")
|
||||||
F16_Default = ("float16", "DefaultFormat")
|
F16_Default = ("float16", "DefaultFormat")
|
||||||
|
@ -589,6 +598,7 @@ class DataType:
|
||||||
F16_NCHW = ("float16", "NCHW")
|
F16_NCHW = ("float16", "NCHW")
|
||||||
F16_NHWC = ("float16", "NHWC")
|
F16_NHWC = ("float16", "NHWC")
|
||||||
F16_HWCN = ("float16", "HWCN")
|
F16_HWCN = ("float16", "HWCN")
|
||||||
|
F16_NDHWC = ("float16", "NDHWC")
|
||||||
|
|
||||||
F32_None = ("float32", "")
|
F32_None = ("float32", "")
|
||||||
F32_Default = ("float32", "DefaultFormat")
|
F32_Default = ("float32", "DefaultFormat")
|
||||||
|
@ -599,6 +609,7 @@ class DataType:
|
||||||
F32_NCHW = ("float32", "NCHW")
|
F32_NCHW = ("float32", "NCHW")
|
||||||
F32_NHWC = ("float32", "NHWC")
|
F32_NHWC = ("float32", "NHWC")
|
||||||
F32_HWCN = ("float32", "HWCN")
|
F32_HWCN = ("float32", "HWCN")
|
||||||
|
F32_NDHWC = ("float32", "NDHWC")
|
||||||
|
|
||||||
F64_None = ("float64", "")
|
F64_None = ("float64", "")
|
||||||
F64_Default = ("float64", "DefaultFormat")
|
F64_Default = ("float64", "DefaultFormat")
|
||||||
|
@ -609,3 +620,4 @@ class DataType:
|
||||||
F64_NCHW = ("float64", "NCHW")
|
F64_NCHW = ("float64", "NCHW")
|
||||||
F64_NHWC = ("float64", "NHWC")
|
F64_NHWC = ("float64", "NHWC")
|
||||||
F64_HWCN = ("float64", "HWCN")
|
F64_HWCN = ("float64", "HWCN")
|
||||||
|
F64_NDHWC = ("float64", "NDHWC")
|
||||||
|
|
|
@ -227,6 +227,23 @@ class SpaceToBatchNet(Cell):
|
||||||
return self.space_to_batch(x)
|
return self.space_to_batch(x)
|
||||||
|
|
||||||
|
|
||||||
|
class PackNet(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(PackNet, self).__init__()
|
||||||
|
self.pack = P.Pack()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.pack((x, x))
|
||||||
|
|
||||||
|
|
||||||
|
class UnpackNet(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(UnpackNet, self).__init__()
|
||||||
|
self.unpack = P.Unpack()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.unpack(x)
|
||||||
|
|
||||||
test_case_array_ops = [
|
test_case_array_ops = [
|
||||||
('CustNet1', {
|
('CustNet1', {
|
||||||
'block': CustNet1(),
|
'block': CustNet1(),
|
||||||
|
@ -249,6 +266,12 @@ test_case_array_ops = [
|
||||||
('SpaceToBatchNet', {
|
('SpaceToBatchNet', {
|
||||||
'block': SpaceToBatchNet(),
|
'block': SpaceToBatchNet(),
|
||||||
'desc_inputs': [Tensor(np.array([[[[1, 2], [3, 4]]]]).astype(np.float16))]}),
|
'desc_inputs': [Tensor(np.array([[[[1, 2], [3, 4]]]]).astype(np.float16))]}),
|
||||||
|
('PackNet', {
|
||||||
|
'block': PackNet(),
|
||||||
|
'desc_inputs': [Tensor(np.array([[[1, 2], [3, 4]]]).astype(np.float16))]}),
|
||||||
|
('UnpackNet', {
|
||||||
|
'block': UnpackNet(),
|
||||||
|
'desc_inputs': [Tensor(np.array([[1, 2], [3, 4]]).astype(np.float16))]}),
|
||||||
]
|
]
|
||||||
|
|
||||||
test_case_lists = [test_case_array_ops]
|
test_case_lists = [test_case_array_ops]
|
||||||
|
|
Loading…
Reference in New Issue