forked from mindspore-Ecosystem/mindspore
!2811 support vm for ParallelConcat
Merge pull request !2811 from jiangjinsheng/vm_parallel_concat
This commit is contained in:
commit bd60db5c11
@@ -285,3 +285,4 @@ from .mod import _mod_tbe
 from .max_pool_grad_grad import _max_pool_grad_grad_tbe
 from .max_pool_grad_grad_with_argmax import _max_pool_grad_grad_with_argmax_tbe
 from .population_count import _population_count_tbe
+from .parallel_concat import _parallel_concat_tbe
@@ -0,0 +1,80 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""ParallelConcat op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+parallel_concat_op_info = TBERegOp("ParallelConcat") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("parallel_concat.so") \
+    .compute_cost(10) \
+    .kernel_name("parallel_concat") \
+    .partial_flag(True) \
+    .attr("shape", "required", "listInt", "all") \
+    .attr("N", "required", "int", "all") \
+    .input(0, "values", False, "dynamic", "all") \
+    .output(0, "output_data", False, "required", "all") \
+    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
+    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
+    .dtype_format(DataType.I16_5HD, DataType.I16_5HD) \
+    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
+    .dtype_format(DataType.U16_5HD, DataType.U16_5HD) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.U32_5HD, DataType.U32_5HD) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.I64_5HD, DataType.I64_5HD) \
+    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
+    .dtype_format(DataType.U64_5HD, DataType.U64_5HD) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \
+    .dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \
+    .dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \
+    .dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \
+    .dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \
+    .dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \
+    .dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \
+    .dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \
+    .dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \
+    .dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \
+    .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \
+    .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \
+    .dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \
+    .dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \
+    .dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \
+    .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \
+    .dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \
+    .dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \
+    .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \
+    .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \
+    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \
+    .get_op_info()
+
+
+@op_info_register(parallel_concat_op_info)
+def _parallel_concat_tbe():
+    """ParallelConcat TBE register"""
+    return
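
The file above only registers the kernel with the op framework: the binary and kernel names, the required `shape` and `N` attributes, one dynamic input, and the dtype/format pairs the kernel accepts. A minimal usage sketch, assuming an Ascend-configured MindSpore environment (tensor values here are illustrative):

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(device_target="Ascend")  # TBE kernels only run on Ascend

# Two (1, 2) float16 inputs match the F16_Default dtype_format pair above,
# so kernel selection can resolve to the registered "parallel_concat" binary.
x1 = Tensor(np.array([[0, 1]]).astype(np.float16))
x2 = Tensor(np.array([[2, 3]]).astype(np.float16))
output = P.ParallelConcat()((x1, x2))  # expected shape: (2, 2)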
@@ -28,6 +28,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
                         ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
                         Shape, Size, Slice, Split, TransShape,
+                        ParallelConcat,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
@@ -329,7 +330,8 @@ __all__ = [
     "InTopK",
     "LRN",
     "Mod",
-    "PopulationCount"
+    "PopulationCount",
+    "ParallelConcat",
 ]

 __all__.sort()
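
The trailing comma added to "PopulationCount" above is load-bearing: adjacent string literals in Python concatenate implicitly, so appending "ParallelConcat" after a comma-less entry would have silently fused the two exports into one bogus name. A quick illustration (a hypothetical list, not the real __all__):

exports = [
    "PopulationCount"   # missing comma: the next literal is glued onto this one
    "ParallelConcat",
]
assert exports == ["PopulationCountParallelConcat"]  # one entry, not two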
@@ -1463,6 +1463,57 @@ class Concat(PrimitiveWithInfer):
         return out


+class ParallelConcat(PrimitiveWithInfer):
+    r"""
+    Concats input tensors along the first dimension.
+
+    All input tensors must have the same shape; the result stacks them along dimension 0.
+
+    Note:
+        The input tensors are all required to have size 1 in the first dimension.
+
+    Inputs:
+        - **values** (tuple, list) - Tuple or list of input tensors.
+
+    Outputs:
+        Tensor, with the same data type as `values`.
+
+    Examples:
+        >>> data1 = Tensor(np.array([[0, 1]]).astype(np.int32))
+        >>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32))
+        >>> op = P.ParallelConcat()
+        >>> output = op((data1, data2))
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init ParallelConcat"""
+
+    def __infer__(self, values):
+        x_shp = values['shape']
+        x_type = values['dtype']
+
+        validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name)
+        first_elem = x_shp[0]
+        args = {}
+        for i, elem in enumerate(x_shp[1:]):
+            j = i + 1
+            args[f'x_type[{j}]'] = x_type[j]
+            validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name)
+            validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)
+        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
+
+        ret_shp = x_shp[0].copy()
+        ret_shp[0] = len(x_shp)
+        self.add_prim_attr('shape', ret_shp)
+        self.add_prim_attr('N', len(x_shp))
+
+        out = {'shape': ret_shp,
+               'dtype': x_type[0],
+               'value': None}
+        return out
+
+
 def _get_pack_shape(x_shape, x_type, axis, prim_name):
     """for pack output shape"""
     validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
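
The __infer__ method above fixes ParallelConcat's contract: at least one input, every input tensor with first dimension 1, all shapes and dtypes identical, and an output whose first dimension is the input count N. The same rule in plain Python, as a sketch with an illustrative helper name (validator plumbing omitted):

def parallel_concat_out_shape(shapes):
    """Mirrors the shape rule in __infer__ above."""
    assert len(shapes) >= 1
    first = shapes[0]
    for shp in shapes[1:]:
        # each input: leading dimension 1 and identical to the first shape
        assert shp[0] == 1 and shp == first
    out = list(first)
    out[0] = len(shapes)  # output leading dim is the number of inputs N
    return out

assert parallel_concat_out_shape([[1, 2], [1, 2]]) == [2, 2]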
@@ -596,6 +596,15 @@ def test_strided_slice_const():
     assert (ret.asnumpy() == np.array([], np.float32).reshape([0, 1, 7, 8, 9, 3, 1])).all()


+class ParallelConcatNet(nn.Cell):
+    def __init__(self):
+        super(ParallelConcatNet, self).__init__()
+        self.parallel_concat = P.ParallelConcat()
+
+    def construct(self, x1, x2):
+        return self.parallel_concat((x1, x2))
+
+
 test_case_math_ops = [
     ('BitwiseAnd', {
         'block': P.BitwiseAnd(),
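
ParallelConcatNet wraps the primitive in an nn.Cell so the generic test harness below can feed it positional tensor inputs. Used directly it would behave roughly like this (a sketch; assumes a configured backend):

net = ParallelConcatNet()
out = net(Tensor([[1, 2]], mstype.float32), Tensor([[5, 6]], mstype.float32))
# out == [[1, 2], [5, 6]]: the two (1, 2) inputs stacked along dimension 0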
@@ -1875,6 +1884,12 @@ test_case_array_ops = [
         'desc_inputs': [[1, 3, 24, 24]],
         'desc_bprop': [[1, 12, 24, 24]],
     }),
+    ('ParallelConcat', {
+        'block': ParallelConcatNet(),
+        'desc_inputs': [Tensor([[1, 2]], mstype.float32),
+                        Tensor([[5, 6]], mstype.float32)],
+        'skip': ['backward'],
+    }),
 ]

 test_case_other_ops = [