Add dynamic shape support for the operator Concat

This commit is contained in:
hedongdong 2021-01-21 20:36:29 +08:00
parent a5cde3fea3
commit 8241dfa443
13 changed files with 260 additions and 19 deletions

View File

@ -21,6 +21,7 @@ namespace mindspore {
namespace kernel {
template <typename T>
void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
node_ = kernel_node;
CheckParam(kernel_node);
axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS));
@ -28,27 +29,28 @@ void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
if (axis_ < 0) {
axis_ = axis_ + SizeToInt(input_1_shape.size());
}
input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_num_; i++) {
auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis_);
input_flat_shape_list_.push_back(flat_shape);
}
}
template <typename T>
bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
size_t input_num = AnfAlgo::GetInputTensorNum(node_);
std::vector<std::vector<size_t>> input_flat_shape_list;
for (size_t i = 0; i < input_num; i++) {
auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(node_, i);
auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis_);
input_flat_shape_list.push_back(flat_shape);
}
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
auto buff_size = outputs[0]->size;
// each input's row of shape after flat are same
auto before_axis = input_flat_shape_list_[0][0];
auto before_axis = input_flat_shape_list[0][0];
for (size_t i = 0; i < before_axis; ++i) {
for (size_t j = 0; j < input_num_; ++j) {
for (size_t j = 0; j < input_num; ++j) {
auto input_j_addr = reinterpret_cast<T *>(inputs[j]->addr);
auto copy_num = input_flat_shape_list_[j][1];
auto copy_num = input_flat_shape_list[j][1];
auto offset = copy_num * i;
auto ret = memcpy_s(output_addr, buff_size, input_j_addr + offset, copy_num * sizeof(T));
if (ret != EOK) {

View File

@ -36,8 +36,7 @@ class ConcatCPUKernel : public CPUKernel {
private:
void CheckParam(const CNodePtr &kernel_node);
int axis_ = 0;
size_t input_num_ = 1;
std::vector<std::vector<size_t>> input_flat_shape_list_;
CNodePtr node_ = nullptr;
};
MS_REG_CPU_KERNEL_T(

View File

@ -140,7 +140,7 @@ void OpTilingCalculater::Init() {
tiling_func_map_ = optiling::OpTilingRegistryInterf::RegisteredOpInterf();
MS_LOG(INFO) << "tiling_func_map_ size:" << tiling_func_map_.size();
for (const auto &iter : tiling_func_map_) {
MS_LOG(INFO) << "Regist tiling func:" << iter.first;
MS_LOG(INFO) << "Register tiling func:" << iter.first;
}
}
@ -150,6 +150,7 @@ std::string GetRealOpType(const std::string &op_type) {
{"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"},
{"SparseGatherV2", "GatherV2"},
{"Pad", "PadD"},
{"Concat", "ConcatD"},
};
auto iter = kOpTypeMap.find(op_type);
if (iter == kOpTypeMap.end()) {

View File

@ -30,6 +30,7 @@
namespace mindspore {
// op name. Op which not exists in operator/ops.h, so define it's name here
constexpr auto kConcatOpName = "Concat";
constexpr auto kUniqueOpName = "Unique";
constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
@ -492,7 +493,8 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalH
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
const std::set<std::string> DynamicShapeConstInputToAttr = {
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName};
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName,
kTransposeOpName, kReduceSumOpName, kConcatOpName};
static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
try {

View File

@ -287,6 +287,8 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);

View File

@ -954,6 +954,90 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
return std::make_shared<AbstractTensor>(kBool, output_shape);
}
AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
const std::string op_name = primitive->name();
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "args_spec_list is empty.";
}
AbstractTuplePtr arg = nullptr;
AbstractTensorPtr tensor_base = nullptr;
size_t tuple_len = 0;
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
if (args_spec_list[0]->isa<AbstractTuple>()) {
CheckArgsSize(op_name, args_spec_list, 1);
arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
tuple_len = arg->elements().size();
tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0);
} else if (args_spec_list[0]->isa<AbstractTensor>()) {
tuple_len = args_spec_list.size();
tensor_base = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
}
MS_EXCEPTION_IF_NULL(tensor_base);
ShapeVector shape_base = tensor_base->shape()->shape();
int64_t rank_base = SizeToLong(shape_base.size());
ShapeVector min_shape_base = tensor_base->shape()->min_shape();
ShapeVector max_shape_base = tensor_base->shape()->max_shape();
(void)CheckMinMaxShape(shape_base, &min_shape_base, &max_shape_base);
primitive->set_attr("T", tensor_base->element()->BuildType());
primitive->set_attr("inputNums", MakeValue(SizeToLong(tuple_len)));
ValuePtr axis = primitive->GetAttr("axis");
// Axis value should be in [-(rank_base + 1), rank_base).
int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
// If axis is negative, add offset(rank_base) to turn it to positive.
axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base));
int64_t all_shp = shape_base[axis_value];
int64_t min_all_shp = min_shape_base[axis_value];
int64_t max_all_shp = max_shape_base[axis_value];
for (size_t i = 1; i < tuple_len; ++i) {
AbstractTensorPtr tensor = nullptr;
if (args_spec_list[0]->isa<AbstractTuple>()) {
tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i);
} else if (args_spec_list[0]->isa<AbstractTensor>()) {
tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
}
ShapeVector shape_tensor = tensor->shape()->shape();
int64_t rank_tensor = SizeToLong(shape_tensor.size());
ShapeVector min_shape_tensor = tensor->shape()->min_shape();
ShapeVector max_shape_tensor = tensor->shape()->max_shape();
(void)CheckMinMaxShape(shape_tensor, &min_shape_tensor, &max_shape_tensor);
(void)CheckDtypeSame(op_name, tensor_base, tensor);
if (rank_tensor != rank_base) {
MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Rank";
}
for (int j = 0; j < rank_base; ++j) {
if (j != axis_value && shape_tensor[j] != shape_base[j]) {
MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Size";
}
}
if (all_shp == -1 || shape_base[axis_value] == -1) {
all_shp = -1;
} else {
all_shp += shape_tensor[axis_value];
}
min_all_shp += min_shape_tensor[axis_value];
max_all_shp += max_shape_tensor[axis_value];
}
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(tensor_base->Broaden());
MS_EXCEPTION_IF_NULL(ret);
auto shape = ret->shape()->shape();
auto min_shape = ret->shape()->min_shape();
auto max_shape = ret->shape()->max_shape();
(void)CheckMinMaxShape(shape, &min_shape, &max_shape);
shape[axis_value] = all_shp;
min_shape[axis_value] = min_all_shp;
max_shape[axis_value] = max_all_shp;
ret->set_shape(std::make_shared<Shape>(shape, min_shape, max_shape));
return ret;
}
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();

View File

@ -81,6 +81,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimMapUniform, {InferImplMapUniform, true}},
{prim::kPrimSplit, {InferImplSplit, true}},
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
{prim::kPrimConcat, {InferImplConcat, true}},
{prim::kPrimRange, {InferImplRange, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},

View File

@ -170,6 +170,7 @@ from .minimum_ds import _minimum_ds_tbe
from .minimum_grad import _minimum_grad_tbe
from .maximum_grad import _maximum_grad_tbe
from .concat import _concat_tbe
from .concat_ds import _concat_ds_tbe
from .slice import _slice_tbe
from .sign import _sign_tbe
from .greater import _greater_tbe

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Concat op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
concat_ds_op_info = TBERegOp("Concat") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("concat_d.so") \
.compute_cost(10) \
.kernel_name("concat_d") \
.partial_flag(True) \
.dynamic_shape(True) \
.attr("axis", "required", "int", "all") \
.input(0, "input_values", False, "dynamic", "all") \
.output(0, "output_data", False, "required", "all") \
.op_pattern("dynamicFormat") \
.dtype_format(DataType.None_None, DataType.None_None) \
.get_op_info()
@op_info_register(concat_ds_op_info)
def _concat_ds_tbe():
"""Concat TBE register"""
return

View File

@ -2148,6 +2148,19 @@ class Concat(PrimitiveWithInfer):
out = {'shape': ret_shp,
'dtype': x_type[0],
'value': value}
if -1 in x_shp[0]:
x_min_shp = input_x['min_shape']
ret_min_shp = x_min_shp[0].copy()
ret_min_shp[axis] = 0
for all_min_shp in x_min_shp:
ret_min_shp[axis] += all_min_shp[axis]
out['min_shape'] = ret_min_shp
x_max_shp = input_x['max_shape']
ret_max_shp = x_max_shp[0].copy()
ret_max_shp[axis] = 0
for all_max_shp in x_max_shp:
ret_max_shp[axis] += all_max_shp[axis]
out['max_shape'] = ret_max_shp
return out
@ -2789,7 +2802,7 @@ class StridedSlice(PrimitiveWithInfer):
if has_ellipsis:
# When there is ellipsis, handle the second half of the ellipsis split.
ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
j += 1
i += ellipsis_occupied_dims
@ -3985,7 +3998,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
offset = 1
for i in range(len(self.block_shape)):
padded = out_shape[i + offset] + self.paddings[i][0] + \
self.paddings[i][1]
self.paddings[i][1]
if padded % self.block_shape[i] != 0:
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
f'block_shape[{i}] {self.block_shape[i]}')

View File

@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self, axis=0):
super(Net, self).__init__()
self.unique = P.Unique()
self.reshape = P.Reshape()
self.concat = P.Concat(axis=axis)
def construct(self, x1, x2):
out1_unique, _ = self.unique(x1)
out2_unique, _ = self.unique(x2)
out1_shape = self.reshape(out1_unique, (1, -1, 2))
out2_shape = self.reshape(out2_unique, (1, -1, 2))
return self.concat((out1_shape, out2_shape))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_concat():
x1 = Tensor(np.array([1, 2, 3, 1, 4, 2]), mstype.int32)
x2 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.int32)
net = Net(axis=1)
output = net(x1, x2)
expect = np.array([[[1, 2], [3, 4], [1, 2], [3, 4], [5, 6]]])
assert (output.asnumpy() == expect).all()

View File

@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Net(nn.Cell):
def __init__(self, axis=0):
super(Net, self).__init__()
self.unique = P.Unique()
self.reshape = P.Reshape()
self.concat = P.Concat(axis=axis)
def construct(self, x1, x2):
out1_unique, _ = self.unique(x1)
out2_unique, _ = self.unique(x2)
out1_shape = self.reshape(out1_unique, (1, -1, 2))
out2_shape = self.reshape(out2_unique, (1, -1, 2))
return self.concat((out1_shape, out2_shape))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dynamic_concat_cpu():
x1 = Tensor(np.array([1, 2, 3, 1, 4, 2]), mstype.int32)
x2 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.int32)
net = Net(axis=1)
output = net(x1, x2)
expect = np.array([[[1, 2], [3, 4], [1, 2], [3, 4], [5, 6]]])
assert (output.asnumpy() == expect).all()

View File

@ -835,14 +835,14 @@ def test_mixed_precision_cast():
assert z.dtype == mstype.float16
def test_while_concat():
def test_while_add():
class Net(nn.Cell):
def __init__(self, data):
super(Net, self).__init__()
self.start = Tensor(0, dtype=mstype.int32)
self.end = Tensor(2, dtype=mstype.int32)
self.out = Tensor(np.zeros([2, 3], dtype=np.float32))
self.concat = P.Concat()
self.add = P.TensorAdd()
def construct(self, inputs):
idx = self.start
@ -850,7 +850,7 @@ def test_while_concat():
out = self.out
while idx < end:
xi = inputs[idx, :, :]
out = self.concat((out, xi))
out = self.add(out, xi)
idx = idx + 1
return out