[feat] [assistant] Add new aicpu operator ParallelConcat
parent 3aeaab22bf
commit d6ee2e99c7
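For context, a minimal usage sketch of the operator this commit wires up on CPU. It is illustrative only and not part of the diff; it mirrors the imports and call style of the test file added below.

import numpy as np
import mindspore.context as context
import mindspore.ops.operations.array_ops as P
from mindspore import Tensor

context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
# Every input tensor must share the same shape and dtype and have shape[0] == 1;
# ParallelConcat stacks them along the first dimension.
x1 = Tensor(np.array([[0, 1]], dtype=np.float32))
x2 = Tensor(np.array([[2, 3]], dtype=np.float32))
out = P.ParallelConcat()((x1, x2))
print(out.shape)  # (2, 2)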
@ -0,0 +1,154 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/parallel_concat_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr int axis = 0;
constexpr size_t kParallelConcatOutputsNum = 1;
}  // namespace

bool ParallelConcatCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                      const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->name();
  if (inputs.empty() || outputs.empty()) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
    return false;
  }
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the data type of input must be float or double, but got: " << kernel_attr << ".";
|
||||
    return false;
  }
  kernel_func_ = func_list_[index].second;
  return true;
}

int ParallelConcatCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                       const std::vector<KernelTensorPtr> &outputs,
                                       const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
    return ret;
  }
  ResetResource();
  std::vector<int64_t> output_shape = outputs[0]->GetShapeVector();
  size_t output_elements = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
  if (output_elements == 0) {
    is_null_input_ = true;
  }
  input_num_ = inputs.size();
  auto x_shape = inputs[0]->GetShapeVector();
  for (size_t i = 0; i < input_num_; i++) {
    if (x_shape != inputs[i]->GetShapeVector()) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', the shape of all tensors must be the same, but got tensor0.shape "
                    << x_shape << " and tensor" << i << ".shape " << inputs[i]->GetShapeVector();
      return KRET_RESIZE_FAILED;
    }
  }
  input_flat_shape_list_.reserve(input_num_);
  for (size_t i = 0; i < input_num_; i++) {
    auto input_shape_i = inputs[i]->GetShapeVector();
    auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis);
    (void)input_flat_shape_list_.emplace_back(flat_shape);
  }
  return KRET_OK;
}

template <typename T>
bool ParallelConcatCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                              const std::vector<kernel::AddressPtr> &outputs) {
  std::vector<T *> input_addr_list;
  for (size_t j = 0; j < input_num_; ++j) {
    auto *tmp_addr = reinterpret_cast<T *>(inputs[j]->addr);
    (void)input_addr_list.emplace_back(tmp_addr);
  }
  auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);

  size_t output_dim_1 = 0;
  for (size_t j = 0; j < input_num_; ++j) {
    output_dim_1 += LongToSize(input_flat_shape_list_[j][1]);
  }

  // after flattening, every input has the same number of rows (dimension 0)
  auto before_axis = LongToSize(input_flat_shape_list_[0][0]);
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; ++i) {
      auto output_ptr = output_addr + i * output_dim_1;
      for (size_t j = 0; j < input_num_; ++j) {
        if (input_flat_shape_list_[j][1] == 0) {
          continue;
        }
        auto copy_num = LongToSize(input_flat_shape_list_[j][1]);
        auto copy_size = copy_num * sizeof(T);
        auto offset = copy_num * i;
        auto ret = memcpy_s(output_ptr, copy_size, input_addr_list[j] + offset, copy_size);
        if (ret != EOK) {
          MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy failed. Error no: " << ret;
        }
        output_ptr += copy_num;
      }
    }
  };
  ParallelLaunchAutoSearch(task, before_axis, this, &parallel_search_info_);
  return true;
}

std::vector<std::pair<KernelAttr, ParallelConcatCpuKernelMod::PCFunc>> ParallelConcatCpuKernelMod::func_list_ = {
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
   &ParallelConcatCpuKernelMod::LaunchKernel<uint8_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
   &ParallelConcatCpuKernelMod::LaunchKernel<uint16_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
   &ParallelConcatCpuKernelMod::LaunchKernel<uint32_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
   &ParallelConcatCpuKernelMod::LaunchKernel<uint64_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
   &ParallelConcatCpuKernelMod::LaunchKernel<int8_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
   &ParallelConcatCpuKernelMod::LaunchKernel<int16_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   &ParallelConcatCpuKernelMod::LaunchKernel<int32_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
   &ParallelConcatCpuKernelMod::LaunchKernel<int64_t>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   &ParallelConcatCpuKernelMod::LaunchKernel<float16>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   &ParallelConcatCpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   &ParallelConcatCpuKernelMod::LaunchKernel<double>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
   &ParallelConcatCpuKernelMod::LaunchKernel<complex64>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
   &ParallelConcatCpuKernelMod::LaunchKernel<complex128>},
  {KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
   &ParallelConcatCpuKernelMod::LaunchKernel<bool>}};

std::vector<KernelAttr> ParallelConcatCpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, PCFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ParallelConcat, ParallelConcatCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,72 @@
/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_

#include <vector>
#include <utility>
#include <complex>
#include <map>
#include <functional>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;

class ParallelConcatCpuKernelMod : public NativeCpuKernelMod {
 public:
  ParallelConcatCpuKernelMod() = default;
  ~ParallelConcatCpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, outputs);
  }

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

 protected:
  void ResetResource() noexcept {
    input_num_ = 0;
    input_flat_shape_list_.clear();
  }
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
  using PCFunc = std::function<bool(ParallelConcatCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                                    const std::vector<kernel::AddressPtr> &)>;

 private:
  size_t input_num_;
  PCFunc kernel_func_;
  bool is_null_input_{false};
  std::vector<ShapeVector> input_flat_shape_list_;
  static std::vector<std::pair<KernelAttr, PCFunc>> func_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PARALLEL_CONCAT_CPU_KERNEL_H_
@ -182,6 +182,7 @@ constexpr auto kOnes = "Ones";
constexpr auto kOnesLike = "OnesLike";
constexpr auto kIdentity = "Identity";
constexpr auto kConcat = "Concat";
constexpr auto kParallelConcat = "ParallelConcat";
constexpr auto kFlattenConcat = "FlattenConcat";
constexpr auto kRightShift = "RightShift";
constexpr auto kDiag = "Diag";
@ -500,6 +501,7 @@ GVAR_DEF(PrimitivePtr, kPrimArrayMap, std::make_shared<Primitive>("array_map"));
GVAR_DEF(PrimitivePtr, kPrimArrayReduce, std::make_shared<Primitive>("array_reduce"));
GVAR_DEF(PrimitivePtr, kPrimCast, std::make_shared<Primitive>("Cast"));
GVAR_DEF(PrimitivePtr, kPrimConcat, std::make_shared<Primitive>(kConcat));
GVAR_DEF(PrimitivePtr, kPrimParallelConcat, std::make_shared<Primitive>(kParallelConcat));
GVAR_DEF(PrimitivePtr, kPrimFlattenConcat, std::make_shared<Primitive>(kFlattenConcat));
GVAR_DEF(PrimitivePtr, kPrimSqueeze, std::make_shared<Primitive>("Squeeze"));
GVAR_DEF(PrimitivePtr, kPrimUnsqueeze, std::make_shared<Primitive>("Unsqueeze"));
@ -0,0 +1,121 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <map>
#include <string>
#include "ops/parallel_concat.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
const int64_t kOneNum = 1;
abstract::ShapePtr ParallelConcatInferShape(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  auto elements = input_args[0]->isa<abstract::AbstractTuple>()
                    ? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
                    : input_args[0]->cast<abstract::AbstractListPtr>()->elements();
  (void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, kOneNum,
                                           prim_name);
  (void)primitive->AddAttr("N", MakeValue(SizeToLong(elements.size())));
  (void)primitive->AddAttr("inputNums", MakeValue(SizeToLong(elements.size())));
  auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(element0);
  auto element0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kShape];
  auto element0_rank = element0_shape.size();
  if (element0_rank < kOneNum) {
    MS_EXCEPTION(ValueError) << "For [" << prim_name << "], the rank of input must be greater than or equal to 1, "
                             << "but got: " << element0_rank << ".";
  }

  auto axis = 0;
  int64_t all_shp = element0_shape[axis];
  for (size_t i = 0; i < elements.size(); ++i) {
    auto shape_i = elements[i]->BuildShape();
    if (shape_i->IsDynamic()) {
      auto ret_shape = element0_shape;
      ret_shape[axis] = -1;
      return std::make_shared<abstract::Shape>(ret_shape);
    }
  }
  for (size_t i = 1; i < elements.size(); ++i) {
    std::string elementi = "element" + std::to_string(i);
    auto elementi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kShape];
    (void)CheckAndConvertUtils::CheckInteger("x" + std::to_string(i) + ".shape[0]", elementi_shape[0], kEqual, kOneNum,
                                             prim_name);
    if (elementi_shape.size() != element0_shape.size()) {
      MS_EXCEPTION(ValueError) << "For [" << prim_name << "], the rank of all elements should be the same, but got "
                               << "x0.rank [" << element0_shape.size() << "] and x" << std::to_string(i) << ".rank ["
                               << elementi_shape.size() << "].";
    }
    for (size_t j = 1; j < element0_rank; ++j) {
      if (elementi_shape[j] != element0_shape[j]) {
        MS_EXCEPTION(ValueError) << "For [" << prim_name << "], the shape of all elements should be the same, but got "
                                 << "x0.shape[" << std::to_string(j) << "] = [" << element0_shape[j] << "] and x"
                                 << std::to_string(i) << ".shape[" << std::to_string(j) << "] = [" << elementi_shape[j]
                                 << "].";
      }
    }

    all_shp = all_shp + elementi_shape[axis];
  }
  auto ret_shape = element0_shape;
  ret_shape[axis] = all_shp;
  return std::make_shared<abstract::Shape>(ret_shape);
}

TypePtr ParallelConcatInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  if (!input_args[0]->isa<abstract::AbstractTuple>() && !input_args[0]->isa<abstract::AbstractList>()) {
    MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the input must be a list or tuple of tensors. But got: "
                            << input_args[0]->ToString() << ".";
  }
  auto elements = input_args[0]->isa<abstract::AbstractTuple>()
                    ? input_args[0]->cast<abstract::AbstractTuplePtr>()->elements()
                    : input_args[0]->cast<abstract::AbstractListPtr>()->elements();
  (void)CheckAndConvertUtils::CheckInteger("concat element num", SizeToLong(elements.size()), kGreaterEqual, kOneNum,
                                           prim_name);
  std::map<std::string, TypePtr> types;
  for (size_t i = 0; i < elements.size(); ++i) {
    std::string elementi = "element" + std::to_string(i);
    (void)types.emplace(elementi, elements[i]->BuildType());
  }
  (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex_and_bool, prim_name);
  return elements[0]->BuildType();
}
}  // namespace

MIND_API_OPERATOR_IMPL(ParallelConcat, BaseOperator);
AbstractBasePtr ParallelConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) {
  const int64_t kInputNum = 1;
  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto infer_type = ParallelConcatInferType(primitive, input_args);
  auto infer_shape = ParallelConcatInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ParallelConcat, prim::kPrimParallelConcat, ParallelConcatInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
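For a quick sanity check of the shape rule the infer code above enforces, here is a small illustrative Python helper. It is not part of this commit and the function name is made up for the example:

def parallel_concat_infer_shape(shapes):
    """Sketch: every input must have shape[0] == 1 and identical trailing
    dimensions; the output stacks the inputs along the first dimension."""
    rank = len(shapes[0])
    assert rank >= 1, "rank of input must be >= 1"
    for s in shapes:
        assert s[0] == 1 and tuple(s[1:]) == tuple(shapes[0][1:])
    return [sum(s[0] for s in shapes)] + list(shapes[0][1:])

# Example: three inputs of shape (1, 2, 4) produce an output of shape [3, 2, 4].
print(parallel_concat_infer_shape([(1, 2, 4), (1, 2, 4), (1, 2, 4)]))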
@ -0,0 +1,42 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_
#define MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_
#include <vector>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameParallelConcat = "ParallelConcat";
/// \brief Concatenates input tensors along the first dimension.
/// Refer to Python API @ref mindspore.ops.ParallelConcat for more details.
class MIND_API ParallelConcat : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(ParallelConcat);
  /// \brief Constructor.
  ParallelConcat() : BaseOperator(kNameParallelConcat) { InitIOName({"x"}, {"y"}); }
  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.ParallelConcat for the inputs.
  void Init() const {}
};
abstract::AbstractBasePtr ParallelConcatInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                              const std::vector<abstract::AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_PARALLEL_PARALLEL_CONCAT_H_
@ -161,3 +161,4 @@ from .reservoir_replay_buffer import _rrb_create_op_cpu
from .reservoir_replay_buffer import _rrb_push_op_cpu
from .reservoir_replay_buffer import _rrb_sample_op_cpu
from .reservoir_replay_buffer import _rrb_destroy_op_cpu
from .parallel_concat import _parallel_concat_aicpu
@ -0,0 +1,42 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""ParallelConcat op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

parallel_concat_op_info = AiCPURegOp("ParallelConcat") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "dynamic") \
    .output(0, "y", "required") \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.C128_Default) \
    .get_op_info()


@op_info_register(parallel_concat_op_info)
def _parallel_concat_aicpu():
    """ParallelConcat AiCPU register"""
    return
@ -2700,7 +2700,7 @@ class ConcatOffsetV1(Primitive):
        """Initialize ConcatOffsetV1"""


class ParallelConcat(PrimitiveWithInfer):
class ParallelConcat(Primitive):
    r"""
    Concats tensor in the first dimension.

@ -2716,17 +2716,21 @@ class ParallelConcat(PrimitiveWithInfer):

    Inputs:
        - **values** (tuple, list) - A tuple or a list of input tensors. The data type and shape of these
          tensors must be the same. The data type is Number except float64.
          tensors must be the same. The supported data type is Number on CPU, and the same for Ascend except
          [float64, complex64, complex128].

    Outputs:
        Tensor, data type is the same as `values`.

    Raises:
        TypeError: If any of the inputs is not a Tensor.
        TypeError: If the data types of these tensors are not the same.
        ValueError: If any tensor.shape[0] is not 1.
        ValueError: If the length of the shape of `values` is less than 1.
        ValueError: If the data type and shape of these tensors are not the same.
        ValueError: If the shapes of these tensors are not the same.

    Supported Platforms:
        ``Ascend``
        ``Ascend`` ``CPU``

    Examples:
        >>> data1 = Tensor(np.array([[0, 1]]).astype(np.int32))
@ -2742,31 +2746,6 @@ class ParallelConcat(PrimitiveWithInfer):
    def __init__(self):
        """Initialize ParallelConcat"""

    def __infer__(self, values):
        x_shp = values['shape']
        x_type = values['dtype']

        validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name)

        args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)

        first_elem = x_shp[0]
        for i, elem in enumerate(x_shp[1:]):
            j = i + 1
            validator.check_equal_int(elem[0], 1, f'x_shp[{j}][0]', self.name)
            validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)

        ret_shp = x_shp[0].copy()
        ret_shp[0] = len(x_shp)
        self.add_prim_attr('shape', ret_shp)
        self.add_prim_attr('N', len(x_shp))

        out = {'shape': ret_shp,
               'dtype': x_type[0],
               'value': None}
        return out


def _get_stack_shape(value, x_shape, x_type, axis, prim_name):
    """for stack output shape"""
@ -0,0 +1,83 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.array_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function


class ParallelConcatNet(nn.Cell):
    def __init__(self):
        super(ParallelConcatNet, self).__init__()
        self.net = P.ParallelConcat()

    @ms_function
    def construct(self, inputs):
        return self.net(inputs)


def parallel_concat(loss):
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    data1 = Tensor(np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.float32))
    data2 = Tensor(np.array([[[9, 10, 11, 12], [13, 14, 15, 16]]], dtype=np.float32))
    inputs = [data1, data2]
    net_ms = ParallelConcatNet()
    out_ms = net_ms(inputs)
    expected = np.array([[[1, 2, 3, 4],
                          [5, 6, 7, 8]],
                         [[9, 10, 11, 12],
                          [13, 14, 15, 16]]], dtype=np.float32)
    assert np.allclose(out_ms.asnumpy(), expected, loss, loss)


def parallel_concat_pynative(loss):
    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
    data1 = Tensor(np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.float64))
    data2 = Tensor(np.array([[[9, 10, 11, 12], [13, 14, 15, 16]]], dtype=np.float64))
    inputs = [data1, data2]
    net_ms = ParallelConcatNet()
    out_ms = net_ms(inputs)
    expected = np.array([[[1, 2, 3, 4],
                          [5, 6, 7, 8]],
                         [[9, 10, 11, 12],
                          [13, 14, 15, 16]]], dtype=np.float64)
    assert np.allclose(out_ms.asnumpy(), expected, loss, loss)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parallel_concat_graph_float32():
    """
    Feature: ParallelConcat
    Description: test cases for ParallelConcat in graph mode with float32 inputs
    Expectation: the result matches the expected output
    """
    parallel_concat(loss=1.0e-4)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_parallel_concat_pynative_float64():
    """
    Feature: ParallelConcat
    Description: test cases for ParallelConcat in PyNative mode with float64 inputs
    Expectation: the result matches the expected output
    """
    parallel_concat_pynative(loss=1.0e-5)