[feat] [assistant] Add new aicpu operator ParallelConcat

This commit is contained in:
王振邦 2022-09-01 12:53:26 +08:00
parent 3aeaab22bf
commit d6ee2e99c7
9 changed files with 525 additions and 29 deletions

View File

@ -0,0 +1,154 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/parallel_concat_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
// ParallelConcat always joins its inputs along the leading dimension.
constexpr int axis{0};
// Number of output tensors the kernel produces.
constexpr size_t kParallelConcatOutputsNum{1};
}  // namespace
bool ParallelConcatCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the data type of input must be float or double, but got: " << kernel_attr << ".";
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
int ParallelConcatCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
ResetResource();
std::vector<int64_t> output_shape = outputs[0]->GetShapeVector();
size_t output_elements = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
if (output_elements == 0) {
is_null_input_ = true;
}
input_num_ = inputs.size();
auto x_shape = inputs[0]->GetShapeVector();
for (size_t i = 0; i < input_num_; i++) {
if (x_shape != inputs[i]->GetShapeVector()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of all tensors must be the same, but got tensor0.shape "
<< x_shape << " and tensor" << i << ".shape " << inputs[i]->GetShapeVector();
}
}
input_flat_shape_list_.reserve(input_num_);
for (size_t i = 0; i < input_num_; i++) {
auto input_shape_i = inputs[i]->GetShapeVector();
auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis);
(void)input_flat_shape_list_.emplace_back(flat_shape);
}
return KRET_OK;
}
// Copies every input buffer into the output buffer, concatenating along `axis`.
//
// T: element type of all inputs and of the output.
// inputs: one address per tensor to concatenate; must match the count seen by Resize().
// outputs: exactly one address receiving the concatenated data.
// Returns false (after logging) if the address counts do not match the state cached by Resize();
// returns true otherwise. A failing memcpy_s raises through MS_LOG(EXCEPTION).
template <typename T>
bool ParallelConcatCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                              const std::vector<kernel::AddressPtr> &outputs) {
  // input_flat_shape_list_ is indexed by input below, so the counts must agree with what Resize() cached.
  if (input_num_ == 0 || inputs.size() != input_num_ || input_flat_shape_list_.size() != input_num_ ||
      outputs.size() != kParallelConcatOutputsNum) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs or outputs is invalid, got "
                  << inputs.size() << " inputs and " << outputs.size() << " outputs.";
    return false;
  }
  // Resize() flagged an output without elements: nothing to copy.
  if (is_null_input_) {
    return true;
  }

  // Gather per-input base pointers and row widths once, instead of recomputing them for every row.
  std::vector<T *> input_addr_list;
  std::vector<size_t> copy_num_list;
  input_addr_list.reserve(input_num_);
  copy_num_list.reserve(input_num_);
  size_t output_dim_1 = 0;
  for (size_t j = 0; j < input_num_; ++j) {
    (void)input_addr_list.emplace_back(reinterpret_cast<T *>(inputs[j]->addr));
    const size_t copy_num = LongToSize(input_flat_shape_list_[j][1]);
    (void)copy_num_list.emplace_back(copy_num);
    output_dim_1 += copy_num;
  }
  auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);

  // each input's row of shape after flat are same
  auto before_axis = LongToSize(input_flat_shape_list_[0][0]);
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      // Row i of the output is the rows i of all inputs laid out back to back.
      auto output_ptr = output_addr + i * output_dim_1;
      for (size_t j = 0; j < input_num_; ++j) {
        const size_t copy_num = copy_num_list[j];
        if (copy_num == 0) {
          continue;
        }
        const size_t copy_size = copy_num * sizeof(T);
        const size_t offset = copy_num * i;
        auto ret = memcpy_s(output_ptr, copy_size, input_addr_list[j] + offset, copy_size);
        if (ret != EOK) {
          MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy failed. Error no: " << ret;
        }
        output_ptr += copy_num;
      }
    }
  };
  ParallelLaunchAutoSearch(task, before_axis, this, &parallel_search_info_);
  return true;
}
// Dispatch table pairing each supported dtype signature with the LaunchKernel instantiation for that element type.
// Init() stores the `second` of the entry whose index MatchKernelAttr() returns, and GetOpSupport() exposes the
// `first` of every entry in the same order, so the two views stay aligned.
// NOTE(review): AddAllSameAttr(true) presumably lets the single listed input dtype apply to every input of this
// variable-input op -- confirm against KernelAttr.
std::vector<std::pair<KernelAttr, ParallelConcatCpuKernelMod::PCFunc>> ParallelConcatCpuKernelMod::func_list_ = {
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
&ParallelConcatCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
&ParallelConcatCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
&ParallelConcatCpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
&ParallelConcatCpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
&ParallelConcatCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&ParallelConcatCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&ParallelConcatCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&ParallelConcatCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&ParallelConcatCpuKernelMod::LaunchKernel<float16>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ParallelConcatCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&ParallelConcatCpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&ParallelConcatCpuKernelMod::LaunchKernel<complex64>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&ParallelConcatCpuKernelMod::LaunchKernel<complex128>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&ParallelConcatCpuKernelMod::LaunchKernel<bool>}};
std::vector<KernelAttr> ParallelConcatCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, PCFunc> &pair) { return pair.first; });
return support_list;
}
// Registers ParallelConcatCpuKernelMod under the name ParallelConcat in the NativeCpuKernelMod kernel factory.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ParallelConcat, ParallelConcatCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,72 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_
#include <vector>
#include <utility>
#include <complex>
#include <map>
#include <functional>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
// Element types for complex-valued tensors: single- and double-precision std::complex.
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
/// \brief CPU kernel of the ParallelConcat operator.
///
/// Init() selects a typed launch function, Resize() caches shape-dependent state, and Launch()
/// forwards to the selected function.
class ParallelConcatCpuKernelMod : public NativeCpuKernelMod {
 public:
  ParallelConcatCpuKernelMod() = default;
  ~ParallelConcatCpuKernelMod() override = default;

  /// \brief Choose the launch function matching the dtypes of inputs/outputs.
  /// \return false if the inputs/outputs are empty or the dtype combination is unsupported.
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  /// \brief Run the kernel through the function selected by Init(); workspace is not used.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    // kernel_func_ is only set by a successful Init(); report a missing Init() explicitly.
    MS_EXCEPTION_IF_NULL(kernel_func_);
    return kernel_func_(this, inputs, outputs);
  }

  /// \brief Recompute the cached shape information for the current input/output shapes.
  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

 protected:
  /// \brief Drop everything cached by a previous Resize() so stale state cannot leak into the next one.
  void ResetResource() noexcept {
    input_num_ = 0;
    is_null_input_ = false;
    input_flat_shape_list_.clear();
  }

  /// \brief Supported dtype signatures, derived from func_list_.
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
  using PCFunc = std::function<bool(ParallelConcatCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                                    const std::vector<kernel::AddressPtr> &)>;

 private:
  // Number of tensors to concatenate; set by Resize(). Initialised so that a kernel launched
  // without Resize() does not read an indeterminate value.
  size_t input_num_{0};
  // Typed launch function chosen by Init().
  PCFunc kernel_func_;
  // True when Resize() found that the output holds no elements.
  bool is_null_input_{false};
  // Per-input shape flattened around the concat axis; filled by Resize().
  std::vector<ShapeVector> input_flat_shape_list_;
  // (dtype signature, launch function) table shared by all instances.
  static std::vector<std::pair<KernelAttr, PCFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_

View File

@ -182,6 +182,7 @@ constexpr auto kOnes = "Ones";
constexpr auto kOnesLike = "OnesLike";
constexpr auto kIdentity = "Identity";
constexpr auto kConcat = "Concat";
// Operator name of ParallelConcat; kPrimParallelConcat is created from it.
constexpr auto kParallelConcat = "ParallelConcat";
constexpr auto kFlattenConcat = "FlattenConcat";
constexpr auto kRightShift = "RightShift";
constexpr auto kDiag = "Diag";
@ -500,6 +501,7 @@ GVAR_DEF(PrimitivePtr, kPrimArrayMap, std::make_shared<Primitive>("array_map"));
GVAR_DEF(PrimitivePtr, kPrimArrayReduce, std::make_shared<Primitive>("array_reduce"));
GVAR_DEF(PrimitivePtr, kPrimCast, std::make_shared<Primitive>("Cast"));
GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>(kConcat));
// Primitive of the ParallelConcat operator, named by kParallelConcat.
GVAR_DEF(PrimitivePtr, kPrimParallelConcat, std::make_shared<Primitive>(kParallelConcat));
GVAR_DEF(PrimitivePtr, kPrimFlattenConcat, std::make_shared<Primitive>(kFlattenConcat));
GVAR_DEF(PrimitivePtr, kPrimSqueeze, std::make_shared<Primitive>("Squeeze"));
GVAR_DEF(PrimitivePtr, kPrimUnsqueeze, std::make_shared<Primitive>("Unsqueeze"));

View File

@ -0,0 +1,121 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <string>
#include "ops/parallel_concat.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
// Shared lower bound / expected value "1": minimum element count, minimum rank and required shape[0].
constexpr int64_t kOneNum{1};
// Infers the output shape of ParallelConcat.
//
// input_args[0] must be a tuple or list of tensors. Every tensor must have rank >= 1, shape[0] == 1 and
// the same remaining dimensions as the first tensor. The result is the first tensor's shape with
// dimension 0 replaced by the number of tensors; if any tensor has a dynamic shape, dimension 0 is -1.
// Also records the number of tensors on the primitive as the attributes "N" and "inputNums".
// Raises TypeError for a non-sequence input and ValueError for rank/shape violations.
abstract::ShapePtr ParallelConcatInferShape(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  // The casts below are only valid for a tuple or list, so reject anything else up front.
  if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
    MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the input must be a list or tuple of tensors. But got "
                            << input_args[0]->ToString() << ".";
  }
  auto elements = input_args[0]->isa<abstract::AbstractTuple>()
                    ? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
                    : input_args[0]->cast<abstract::AbstractListPtr>()->elements();
  (void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, kOneNum,
                                           prim_name);
  (void)primitive->AddAttr("N", MakeValue(SizeToLong(elements.size())));
  (void)primitive->AddAttr("inputNums", MakeValue(SizeToLong(elements.size())));

  auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(element0);
  auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
  const size_t element0_rank = element0_shape.size();
  if (element0_shape.empty()) {
    MS_EXCEPTION(ValueError) << "For [" << prim_name
                             << "], the rank of input must be greater than or equal to 1. But got " << element0_rank
                             << ".";
  }

  constexpr size_t kAxis = 0;
  // With any dynamic input the static checks below cannot be applied; report an unknown leading dimension.
  for (size_t i = 0; i < elements.size(); ++i) {
    auto shape_i = elements[i]->BuildShape();
    if (shape_i->IsDynamic()) {
      auto ret_shape = element0_shape;
      ret_shape[kAxis] = -1;
      return std::make_shared<abstract::Shape>(ret_shape);
    }
  }

  // Validate every element, including the first one, whose shape[0] must also be 1.
  for (size_t i = 0; i < elements.size(); ++i) {
    auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
    // Check the rank first so that the indexed accesses below stay in bounds.
    if (elementi_shape.size() != element0_rank) {
      MS_EXCEPTION(ValueError) << "For [" << prim_name << "], the rank of all elements should be the same, but got "
                               << "x0.rank [" << element0_rank << "] and x" << std::to_string(i) << ".rank ["
                               << elementi_shape.size() << "].";
    }
    (void)CheckAndConvertUtils::CheckInteger("x" + std::to_string(i) + ".shape[0]", elementi_shape[kAxis], kEqual,
                                             kOneNum, prim_name);
    for (size_t j = 1; j < element0_rank; ++j) {
      if (elementi_shape[j] != element0_shape[j]) {
        MS_EXCEPTION(ValueError) << "For [" << prim_name << "], the shape of all elements should be the same, but got "
                                 << "x0.shape[" << std::to_string(j) << "] = [" << element0_shape[j] << "] and x"
                                 << std::to_string(i) << ".shape[" << std::to_string(j) << "] = [" << elementi_shape[j]
                                 << "].";
      }
    }
  }

  // Each element contributes exactly one row, so the leading dimension is the element count.
  auto ret_shape = element0_shape;
  ret_shape[kAxis] = SizeToLong(elements.size());
  return std::make_shared<abstract::Shape>(ret_shape);
}
// Infers the output dtype of ParallelConcat.
//
// input_args[0] must be a tuple or list with at least one element; all elements must be tensors of one
// common dtype taken from common_valid_types_with_complex_and_bool. Returns the type of the first element.
// Raises TypeError when the input is not a tuple/list; the dtype checks are delegated to CheckTensorTypeSame.
TypePtr ParallelConcatInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
    // Trailing space keeps the offending value from running into "got" in the message.
    MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the input must be a list or tuple of tensors. But got "
                            << input_args[0]->ToString() << ".";
  }
  auto elements = input_args[0]->isa<abstract::AbstractTuple>()
                    ? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
                    : input_args[0]->cast<abstract::AbstractListPtr>()->elements();
  (void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, kOneNum,
                                           prim_name);
  // Name each element so that a dtype mismatch can be reported per element.
  std::map<std::string, TypePtr> types;
  for (size_t i = 0; i < elements.size(); ++i) {
    std::string elementi = "element" + std::to_string(i);
    (void)types.emplace(elementi, elements[i]->BuildType());
  }
  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex_and_bool, prim_name);
  return elements[0]->BuildType();
}
} // namespace
// NOTE(review): presumably emits the mindapi implementation boilerplate for ParallelConcat with BaseOperator as
// its parent -- confirm against mindapi/src/helper.h.
MIND_API_OPERATOR_IMPL(ParallelConcat, BaseOperator);
// Entry point of shape and type inference for ParallelConcat.
//
// Validates that at least one non-null argument is present, then combines the inferred dtype and shape
// into the abstract value of the output. Errors are raised by the checks and the two infer helpers.
AbstractBasePtr ParallelConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) {
  // primitive->name() is used right below, so validate the pointer before the first dereference.
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t kInputNum = 1;
  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  // Type inference runs first: it rejects a non-sequence input before the shape helper inspects it.
  auto infer_type = ParallelConcatInferType(primitive, input_args);
  auto infer_shape = ParallelConcatInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
// Registers ParallelConcatInfer as the infer function of prim::kPrimParallelConcat.
// NOTE(review): the trailing `nullptr, true` presumably mean "no value-infer function" and "enabled in the
// infer white list" -- confirm against abstract/ops/primitive_infer_map.h.
REGISTER_PRIMITIVE_EVAL_IMPL(ParallelConcat, prim::kPrimParallelConcat, ParallelConcatInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_
#define MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
/// \brief Registered name of the ParallelConcat operator.
constexpr auto kNameParallelConcat = "ParallelConcat";
/// \brief Operator declaration of ParallelConcat; it takes no axis attribute.
/// Refer to Python API @ref mindspore.ops.ParallelConcat for more details.
class MIND_API ParallelConcat : public BaseOperator {
public:
MIND_API_BASE_MEMBER(ParallelConcat);
/// \brief Constructor. Names the single input "x" and the single output "y".
ParallelConcat() : BaseOperator(kNameParallelConcat) { InitIOName({"x"}, {"y"}); }
/// \brief Init. The operator has no attributes to set, so this is a no-op.
/// Refer to the parameters of Python API @ref mindspore.ops.ParallelConcat for the inputs.
void Init() const {}
};
/// \brief Infer the abstract value (shape and dtype) of the ParallelConcat output.
/// \param[in] primitive The ParallelConcat primitive being inferred.
/// \param[in] input_args Abstract values of the operator inputs.
/// \return The abstract value of the output.
abstract::AbstractBasePtr ParallelConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_

View File

@ -161,3 +161,4 @@ from .reservoir_replay_buffer import _rrb_create_op_cpu
from .reservoir_replay_buffer import _rrb_push_op_cpu
from .reservoir_replay_buffer import _rrb_sample_op_cpu
from .reservoir_replay_buffer import _rrb_destroy_op_cpu
from .parallel_concat import _parallel_concat_aicpu

View File

@ -0,0 +1,42 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ParallelConcat op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU operator information for ParallelConcat: one dynamic input "x", one required output "y",
# and one dtype_format entry per supported dtype (input and output dtypes always match).
parallel_concat_op_info = (
    AiCPURegOp("ParallelConcat")
    .fusion_type("OPAQUE")
    .input(0, "x", "dynamic")
    .output(0, "y", "required")
    .dtype_format(DataType.U8_Default, DataType.U8_Default)
    .dtype_format(DataType.U16_Default, DataType.U16_Default)
    .dtype_format(DataType.U32_Default, DataType.U32_Default)
    .dtype_format(DataType.U64_Default, DataType.U64_Default)
    .dtype_format(DataType.I8_Default, DataType.I8_Default)
    .dtype_format(DataType.I16_Default, DataType.I16_Default)
    .dtype_format(DataType.I32_Default, DataType.I32_Default)
    .dtype_format(DataType.I64_Default, DataType.I64_Default)
    .dtype_format(DataType.F16_Default, DataType.F16_Default)
    .dtype_format(DataType.F32_Default, DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.F64_Default)
    .dtype_format(DataType.C64_Default, DataType.C64_Default)
    .dtype_format(DataType.C128_Default, DataType.C128_Default)
    .get_op_info()
)
@op_info_register(parallel_concat_op_info)
def _parallel_concat_aicpu():
    """Register the ParallelConcat operator information built with AiCPURegOp."""
    return None

View File

@ -2700,7 +2700,7 @@ class ConcatOffsetV1(Primitive):
"""Initialize ConcatOffsetV1"""
class ParallelConcat(PrimitiveWithInfer):
class ParallelConcat(Primitive):
r"""
Concats tensor in the first dimension.
@ -2716,17 +2716,21 @@ class ParallelConcat(PrimitiveWithInfer):
Inputs:
- **values** (tuple, list) - A tuple or a list of input tensors. The data type and shape of these
tensors must be the same. The data type is Number except float64.
tensors must be the same. The supported data type is Number on CPU, the same for Ascend except
[float64, complex64, complex128].
Outputs:
Tensor, data type is the same as `values`.
Raises:
TypeError: If any type of the inputs is not a Tensor.
TypeError: If the data type of these tensors are not the same.
ValueError: If any tensor.shape[0] is not 1.
ValueError: If length of shape of `values` is less than 1.
ValueError: The data type and shape of these tensors are not the same.
ValueError: If the shape of these tensors are not the same.
Supported Platforms:
``Ascend``
``Ascend`` ``CPU``
Examples:
>>> data1 = Tensor(np.array([[0, 1]]).astype(np.int32))
@ -2742,31 +2746,6 @@ class ParallelConcat(PrimitiveWithInfer):
def __init__(self):
"""Initialize ParallelConcat"""
def __infer__(self, values):
x_shp = values['shape']
x_type = values['dtype']
validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name)
args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
first_elem = x_shp[0]
for i, elem in enumerate(x_shp[1:]):
j = i + 1
validator.check_equal_int(elem[0], 1, f'x_shp[{j}][0]', self.name)
validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)
ret_shp = x_shp[0].copy()
ret_shp[0] = len(x_shp)
self.add_prim_attr('shape', ret_shp)
self.add_prim_attr('N', len(x_shp))
out = {'shape': ret_shp,
'dtype': x_type[0],
'value': None}
return out
def _get_stack_shape(value, x_shape, x_type, axis, prim_name):
"""for stack output shape"""

View File

@ -0,0 +1,83 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.array_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
class ParallelConcatNet(nn.Cell):
    """Minimal Cell that applies the ParallelConcat primitive to its input."""

    def __init__(self):
        super().__init__()
        self.net = P.ParallelConcat()

    @ms_function
    def construct(self, inputs):
        return self.net(inputs)
def parallel_concat(loss):
    """Run ParallelConcat on two float32 tensors in graph mode on CPU and compare with the expected array.

    Args:
        loss: tolerance used as both rtol and atol in the comparison.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    first = Tensor(np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.float32))
    second = Tensor(np.array([[[9, 10, 11, 12], [13, 14, 15, 16]]], dtype=np.float32))
    net_ms = ParallelConcatNet()
    out_ms = net_ms([first, second])
    expected = np.array(
        [
            [[1, 2, 3, 4], [5, 6, 7, 8]],
            [[9, 10, 11, 12], [13, 14, 15, 16]],
        ],
        dtype=np.float32,
    )
    assert np.allclose(out_ms.asnumpy(), expected, rtol=loss, atol=loss)
def parallel_concat_pynative(loss):
    """Run ParallelConcat on two float64 tensors in PyNative mode on CPU and compare with the expected array.

    Args:
        loss: tolerance used as both rtol and atol in the comparison.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
    first = Tensor(np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.float64))
    second = Tensor(np.array([[[9, 10, 11, 12], [13, 14, 15, 16]]], dtype=np.float64))
    net_ms = ParallelConcatNet()
    out_ms = net_ms([first, second])
    expected = np.array(
        [
            [[1, 2, 3, 4], [5, 6, 7, 8]],
            [[9, 10, 11, 12], [13, 14, 15, 16]],
        ],
        dtype=np.float64,
    )
    assert np.allclose(out_ms.asnumpy(), expected, rtol=loss, atol=loss)
# The helper pins device_target='CPU', so the case is marked for the CPU platform rather than GPU.
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_grad_graph_float32_int32_int32():
    """
    Feature: ParallelConcat CPU kernel
    Description: concatenate two float32 tensors with ParallelConcat in graph mode
    Expectation: the result matches the expected numpy array
    """
    parallel_concat(loss=1.0e-4)
# The helper pins device_target='CPU', so the case is marked for the CPU platform rather than GPU.
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_grad_pynative_float64_int64_int64():
    """
    Feature: ParallelConcat CPU kernel
    Description: concatenate two float64 tensors with ParallelConcat in PyNative mode
    Expectation: the result matches the expected numpy array
    """
    parallel_concat_pynative(loss=1.0e-5)