forked from mindspore-Ecosystem/mindspore
Add dynamic shape support for the operator Concat
This commit is contained in:
parent
a5cde3fea3
commit
8241dfa443
|
@ -21,6 +21,7 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
template <typename T>
|
||||
void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
node_ = kernel_node;
|
||||
CheckParam(kernel_node);
|
||||
|
||||
axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS));
|
||||
|
@ -28,27 +29,28 @@ void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
if (axis_ < 0) {
|
||||
axis_ = axis_ + SizeToInt(input_1_shape.size());
|
||||
}
|
||||
|
||||
input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t i = 0; i < input_num_; i++) {
|
||||
auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
|
||||
auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis_);
|
||||
input_flat_shape_list_.push_back(flat_shape);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(node_);
|
||||
std::vector<std::vector<size_t>> input_flat_shape_list;
|
||||
for (size_t i = 0; i < input_num; i++) {
|
||||
auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(node_, i);
|
||||
auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis_);
|
||||
input_flat_shape_list.push_back(flat_shape);
|
||||
}
|
||||
|
||||
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto buff_size = outputs[0]->size;
|
||||
// each input's row of shape after flat are same
|
||||
auto before_axis = input_flat_shape_list_[0][0];
|
||||
auto before_axis = input_flat_shape_list[0][0];
|
||||
for (size_t i = 0; i < before_axis; ++i) {
|
||||
for (size_t j = 0; j < input_num_; ++j) {
|
||||
for (size_t j = 0; j < input_num; ++j) {
|
||||
auto input_j_addr = reinterpret_cast<T *>(inputs[j]->addr);
|
||||
auto copy_num = input_flat_shape_list_[j][1];
|
||||
auto copy_num = input_flat_shape_list[j][1];
|
||||
auto offset = copy_num * i;
|
||||
auto ret = memcpy_s(output_addr, buff_size, input_j_addr + offset, copy_num * sizeof(T));
|
||||
if (ret != EOK) {
|
||||
|
|
|
@ -36,8 +36,7 @@ class ConcatCPUKernel : public CPUKernel {
|
|||
private:
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
int axis_ = 0;
|
||||
size_t input_num_ = 1;
|
||||
std::vector<std::vector<size_t>> input_flat_shape_list_;
|
||||
CNodePtr node_ = nullptr;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
|
|
|
@ -140,7 +140,7 @@ void OpTilingCalculater::Init() {
|
|||
tiling_func_map_ = optiling::OpTilingRegistryInterf::RegisteredOpInterf();
|
||||
MS_LOG(INFO) << "tiling_func_map_ size:" << tiling_func_map_.size();
|
||||
for (const auto &iter : tiling_func_map_) {
|
||||
MS_LOG(INFO) << "Regist tiling func:" << iter.first;
|
||||
MS_LOG(INFO) << "Register tiling func:" << iter.first;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -150,6 +150,7 @@ std::string GetRealOpType(const std::string &op_type) {
|
|||
{"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"},
|
||||
{"SparseGatherV2", "GatherV2"},
|
||||
{"Pad", "PadD"},
|
||||
{"Concat", "ConcatD"},
|
||||
};
|
||||
auto iter = kOpTypeMap.find(op_type);
|
||||
if (iter == kOpTypeMap.end()) {
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
// op name. Op which not exists in operator/ops.h, so define it's name here
|
||||
constexpr auto kConcatOpName = "Concat";
|
||||
constexpr auto kUniqueOpName = "Unique";
|
||||
constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
|
||||
constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
|
||||
|
@ -492,7 +493,8 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalH
|
|||
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
|
||||
|
||||
const std::set<std::string> DynamicShapeConstInputToAttr = {
|
||||
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName};
|
||||
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName,
|
||||
kTransposeOpName, kReduceSumOpName, kConcatOpName};
|
||||
|
||||
static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
|
||||
try {
|
||||
|
|
|
@ -287,6 +287,8 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
|
|
|
@ -954,6 +954,90 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
|
|||
return std::make_shared<AbstractTensor>(kBool, output_shape);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const std::string op_name = primitive->name();
|
||||
if (args_spec_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "args_spec_list is empty.";
|
||||
}
|
||||
|
||||
AbstractTuplePtr arg = nullptr;
|
||||
AbstractTensorPtr tensor_base = nullptr;
|
||||
size_t tuple_len = 0;
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
if (args_spec_list[0]->isa<AbstractTuple>()) {
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
tuple_len = arg->elements().size();
|
||||
tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0);
|
||||
} else if (args_spec_list[0]->isa<AbstractTensor>()) {
|
||||
tuple_len = args_spec_list.size();
|
||||
tensor_base = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(tensor_base);
|
||||
ShapeVector shape_base = tensor_base->shape()->shape();
|
||||
int64_t rank_base = SizeToLong(shape_base.size());
|
||||
ShapeVector min_shape_base = tensor_base->shape()->min_shape();
|
||||
ShapeVector max_shape_base = tensor_base->shape()->max_shape();
|
||||
(void)CheckMinMaxShape(shape_base, &min_shape_base, &max_shape_base);
|
||||
|
||||
primitive->set_attr("T", tensor_base->element()->BuildType());
|
||||
primitive->set_attr("inputNums", MakeValue(SizeToLong(tuple_len)));
|
||||
|
||||
ValuePtr axis = primitive->GetAttr("axis");
|
||||
// Axis value should be in [-(rank_base + 1), rank_base).
|
||||
int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
|
||||
// If axis is negative, add offset(rank_base) to turn it to positive.
|
||||
axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base));
|
||||
|
||||
int64_t all_shp = shape_base[axis_value];
|
||||
int64_t min_all_shp = min_shape_base[axis_value];
|
||||
int64_t max_all_shp = max_shape_base[axis_value];
|
||||
for (size_t i = 1; i < tuple_len; ++i) {
|
||||
AbstractTensorPtr tensor = nullptr;
|
||||
if (args_spec_list[0]->isa<AbstractTuple>()) {
|
||||
tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i);
|
||||
} else if (args_spec_list[0]->isa<AbstractTensor>()) {
|
||||
tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
|
||||
}
|
||||
ShapeVector shape_tensor = tensor->shape()->shape();
|
||||
int64_t rank_tensor = SizeToLong(shape_tensor.size());
|
||||
ShapeVector min_shape_tensor = tensor->shape()->min_shape();
|
||||
ShapeVector max_shape_tensor = tensor->shape()->max_shape();
|
||||
(void)CheckMinMaxShape(shape_tensor, &min_shape_tensor, &max_shape_tensor);
|
||||
(void)CheckDtypeSame(op_name, tensor_base, tensor);
|
||||
if (rank_tensor != rank_base) {
|
||||
MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Rank";
|
||||
}
|
||||
for (int j = 0; j < rank_base; ++j) {
|
||||
if (j != axis_value && shape_tensor[j] != shape_base[j]) {
|
||||
MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Size";
|
||||
}
|
||||
}
|
||||
if (all_shp == -1 || shape_base[axis_value] == -1) {
|
||||
all_shp = -1;
|
||||
} else {
|
||||
all_shp += shape_tensor[axis_value];
|
||||
}
|
||||
min_all_shp += min_shape_tensor[axis_value];
|
||||
max_all_shp += max_shape_tensor[axis_value];
|
||||
}
|
||||
|
||||
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(tensor_base->Broaden());
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto shape = ret->shape()->shape();
|
||||
auto min_shape = ret->shape()->min_shape();
|
||||
auto max_shape = ret->shape()->max_shape();
|
||||
(void)CheckMinMaxShape(shape, &min_shape, &max_shape);
|
||||
shape[axis_value] = all_shp;
|
||||
min_shape[axis_value] = min_all_shp;
|
||||
max_shape[axis_value] = max_all_shp;
|
||||
ret->set_shape(std::make_shared<Shape>(shape, min_shape, max_shape));
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string &op_name = primitive->name();
|
||||
|
|
|
@ -81,6 +81,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimMapUniform, {InferImplMapUniform, true}},
|
||||
{prim::kPrimSplit, {InferImplSplit, true}},
|
||||
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
|
||||
{prim::kPrimConcat, {InferImplConcat, true}},
|
||||
{prim::kPrimRange, {InferImplRange, true}},
|
||||
// Structure
|
||||
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
|
||||
|
|
|
@ -170,6 +170,7 @@ from .minimum_ds import _minimum_ds_tbe
|
|||
from .minimum_grad import _minimum_grad_tbe
|
||||
from .maximum_grad import _maximum_grad_tbe
|
||||
from .concat import _concat_tbe
|
||||
from .concat_ds import _concat_ds_tbe
|
||||
from .slice import _slice_tbe
|
||||
from .sign import _sign_tbe
|
||||
from .greater import _greater_tbe
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Concat op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
concat_ds_op_info = TBERegOp("Concat") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("concat_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("concat_d") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.attr("axis", "required", "int", "all") \
|
||||
.input(0, "input_values", False, "dynamic", "all") \
|
||||
.output(0, "output_data", False, "required", "all") \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.dtype_format(DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(concat_ds_op_info)
|
||||
def _concat_ds_tbe():
|
||||
"""Concat TBE register"""
|
||||
return
|
|
@ -2148,6 +2148,19 @@ class Concat(PrimitiveWithInfer):
|
|||
out = {'shape': ret_shp,
|
||||
'dtype': x_type[0],
|
||||
'value': value}
|
||||
if -1 in x_shp[0]:
|
||||
x_min_shp = input_x['min_shape']
|
||||
ret_min_shp = x_min_shp[0].copy()
|
||||
ret_min_shp[axis] = 0
|
||||
for all_min_shp in x_min_shp:
|
||||
ret_min_shp[axis] += all_min_shp[axis]
|
||||
out['min_shape'] = ret_min_shp
|
||||
x_max_shp = input_x['max_shape']
|
||||
ret_max_shp = x_max_shp[0].copy()
|
||||
ret_max_shp[axis] = 0
|
||||
for all_max_shp in x_max_shp:
|
||||
ret_max_shp[axis] += all_max_shp[axis]
|
||||
out['max_shape'] = ret_max_shp
|
||||
return out
|
||||
|
||||
|
||||
|
@ -2789,7 +2802,7 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
if has_ellipsis:
|
||||
# When there is ellipsis, handle the second half of the ellipsis split.
|
||||
ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
|
||||
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
|
||||
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
|
||||
ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
|
||||
j += 1
|
||||
i += ellipsis_occupied_dims
|
||||
|
@ -3985,7 +3998,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
offset = 1
|
||||
for i in range(len(self.block_shape)):
|
||||
padded = out_shape[i + offset] + self.paddings[i][0] + \
|
||||
self.paddings[i][1]
|
||||
self.paddings[i][1]
|
||||
if padded % self.block_shape[i] != 0:
|
||||
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
|
||||
f'block_shape[{i}] {self.block_shape[i]}')
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, axis=0):
|
||||
super(Net, self).__init__()
|
||||
self.unique = P.Unique()
|
||||
self.reshape = P.Reshape()
|
||||
self.concat = P.Concat(axis=axis)
|
||||
|
||||
def construct(self, x1, x2):
|
||||
out1_unique, _ = self.unique(x1)
|
||||
out2_unique, _ = self.unique(x2)
|
||||
out1_shape = self.reshape(out1_unique, (1, -1, 2))
|
||||
out2_shape = self.reshape(out2_unique, (1, -1, 2))
|
||||
return self.concat((out1_shape, out2_shape))
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_concat():
|
||||
x1 = Tensor(np.array([1, 2, 3, 1, 4, 2]), mstype.int32)
|
||||
x2 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.int32)
|
||||
net = Net(axis=1)
|
||||
output = net(x1, x2)
|
||||
expect = np.array([[[1, 2], [3, 4], [1, 2], [3, 4], [5, 6]]])
|
||||
assert (output.asnumpy() == expect).all()
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, axis=0):
|
||||
super(Net, self).__init__()
|
||||
self.unique = P.Unique()
|
||||
self.reshape = P.Reshape()
|
||||
self.concat = P.Concat(axis=axis)
|
||||
|
||||
def construct(self, x1, x2):
|
||||
out1_unique, _ = self.unique(x1)
|
||||
out2_unique, _ = self.unique(x2)
|
||||
out1_shape = self.reshape(out1_unique, (1, -1, 2))
|
||||
out2_shape = self.reshape(out2_unique, (1, -1, 2))
|
||||
return self.concat((out1_shape, out2_shape))
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dynamic_concat_cpu():
|
||||
x1 = Tensor(np.array([1, 2, 3, 1, 4, 2]), mstype.int32)
|
||||
x2 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.int32)
|
||||
net = Net(axis=1)
|
||||
output = net(x1, x2)
|
||||
expect = np.array([[[1, 2], [3, 4], [1, 2], [3, 4], [5, 6]]])
|
||||
assert (output.asnumpy() == expect).all()
|
|
@ -835,14 +835,14 @@ def test_mixed_precision_cast():
|
|||
assert z.dtype == mstype.float16
|
||||
|
||||
|
||||
def test_while_concat():
|
||||
def test_while_add():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, data):
|
||||
super(Net, self).__init__()
|
||||
self.start = Tensor(0, dtype=mstype.int32)
|
||||
self.end = Tensor(2, dtype=mstype.int32)
|
||||
self.out = Tensor(np.zeros([2, 3], dtype=np.float32))
|
||||
self.concat = P.Concat()
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, inputs):
|
||||
idx = self.start
|
||||
|
@ -850,7 +850,7 @@ def test_while_concat():
|
|||
out = self.out
|
||||
while idx < end:
|
||||
xi = inputs[idx, :, :]
|
||||
out = self.concat((out, xi))
|
||||
out = self.add(out, xi)
|
||||
idx = idx + 1
|
||||
return out
|
||||
|
||||
|
|
Loading…
Reference in New Issue