!43231 [MS][OPS] update UnsortedSegmentSum ops npu dynamic compile static and add irfission

Merge pull request !43231 from luoyuan/UnsortedSegmentSum-dynamic-compile-static-and-add-irfission
This commit is contained in:
i-robot 2022-09-30 07:41:16 +00:00 committed by Gitee
commit 5108f4af77
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
21 changed files with 477 additions and 167 deletions

View File

@ -21,6 +21,7 @@
#include "utils/log_adapter.h"
#include "mindspore/core/ops/core_ops.h"
#include "backend/common/optimizer/helper.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace opt {
@ -190,6 +191,12 @@ CNodePtr ConstInputToAttr(const CNodePtr &cnode, const mindspore::HashSet<size_t
if (kernel_graph != nullptr) {
kernel_graph->FrontBackendlMapUpdate(cnode, new_cnode);
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (primitive->name() == "UnsortedSegmentSum" && backend == kAscendDevice) {
primitive->set_name("UnsortedSegmentSumD");
}
return new_cnode;
}
return cnode;

View File

@ -415,6 +415,7 @@ constexpr auto kUnsortedSegmentMaxOpName = "UnsortedSegmentMax";
constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
constexpr auto kUnsortedSegmentProdOpName = "UnsortedSegmentProd";
constexpr auto kUnsortedSegmentSumOpName = "UnsortedSegmentSum";
constexpr auto kUnsortedSegmentSumDOpName = "UnsortedSegmentSumD";
constexpr auto kUpdateCacheOpName = "UpdateCache";
constexpr auto kUpdateStateOpName = "UpdateState";

View File

@ -304,6 +304,12 @@ bool DataConvert::RunOpConvertConstInputToAttr(const FrontendOpRunInfoPtr &op_ru
}
}
(void)op_prim->AddAttr(input_name, v);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (op_prim->name() == "UnsortedSegmentSum" && backend == kAscendDevice) {
op_prim->set_name("UnsortedSegmentSumD");
}
(void)op_run_info->index_with_value.emplace_back(std::make_pair(input_index, v));
return true;
}

View File

@ -31,6 +31,7 @@
#include "plugin/device/ascend/optimizer/ir_fusion/fused_batch_norm_fusion.h"
#include "plugin/device/ascend/optimizer/ir_fission/layer_norm_grad_split.h"
#include "plugin/device/ascend/optimizer/ir_fission/unsorted_segment_sum_fission.h"
#include "plugin/device/ascend/optimizer/ir_fission/unsorted_segment_sum_d_fission.h"
#include "plugin/device/ascend/optimizer/ir_fission/gather_v2_ds_fission.h"
#include "plugin/device/ascend/optimizer/ir_fission/bce_with_logits_loss_fission.h"
#include "plugin/device/ascend/optimizer/ir_fission/broadcastto_fission.h"
@ -89,6 +90,7 @@
#include "plugin/device/ascend/optimizer/ir_fusion/transposed_update_fusion.h"
#include "plugin/device/ascend/optimizer/ir_fusion/softmax_dropout_do_mask_v3_fusion.h"
#include "plugin/device/ascend/optimizer/ir_fusion/conv2d_backprop_input_dilation_fusion.h"
#include "plugin/device/ascend/optimizer/ir_fusion/unsorted_segment_sum_replace.h"
#include "plugin/device/ascend/optimizer/format_type/insert_trans_op.h"
#include "plugin/device/ascend/optimizer/format_type/trans_op_format_refine.h"
#include "plugin/device/ascend/optimizer/format_type/dynamic_rnn_grad_reformat.h"
@ -215,6 +217,7 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {
void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
MS_EXCEPTION_IF_NULL(ir_fusion_pm);
ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumReplace>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
@ -260,7 +263,8 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<PackFission>());
ir_fusion_pm->AddPass(std::make_shared<ConcatFission>());
ir_fusion_pm->AddPass(std::make_shared<ReduceMinFission>());
ir_fusion_pm->AddPass(std::make_shared<UnsortSegmentSumFission>());
ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumFission>());
ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumDFission>());
ir_fusion_pm->AddPass(std::make_shared<GatherV2DsFission>());
ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
ir_fusion_pm->AddPass(std::make_shared<BroadcasttoFission>());
@ -418,6 +422,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
#endif
auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumReplace>());
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());

View File

@ -0,0 +1,163 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/ir_fission/unsorted_segment_sum_d_fission.h"
#include <memory>
#include <vector>
#include <algorithm>
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "ir/primitive.h"
#include "include/common/utils/utils.h"
namespace mindspore {
namespace opt {
namespace {
// Checks whether the UnsortedSegmentSumD node qualifies for fission: the input
// count must match, both input shapes must be known and non-empty, the last dim
// of input0 must be 1, and input0 must have a higher rank than input1.
bool CheckInputs(const CNodePtr &origin_node) {
  MS_EXCEPTION_IF_NULL(origin_node);
  const size_t input_num = common::AnfAlgo::GetInputTensorNum(origin_node);
  if (input_num != kUnsortedSegmentSumInputTensorNum) {
    MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputTensorNum
                  << ". CNode= " << origin_node->DebugString();
    return false;
  }
  const auto data_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
  const auto ids_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
  if (data_shape.empty() || ids_shape.empty()) {
    return false;
  }
  const auto last_dim = data_shape[data_shape.size() - 1];
  if (last_dim != 1) {
    MS_LOG(DEBUG) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is " << last_dim;
    return false;
  }
  return data_shape.size() > ids_shape.size();
}
} // namespace
// Builds a Padding node that widens the last dim of input0 from 1 to
// pad_dim_size, copying scope and propagating the (possibly dynamic) shape.
CNodePtr UnsortedSegmentSumDFission::CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node,
                                                   const size_t &pad_dim_size) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(origin_node);
  std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPaddingOpName)),
                                        origin_node->input(kIndex1)};
  auto pad_node = NewCNode(pad_inputs, graph);
  MS_EXCEPTION_IF_NULL(pad_node);
  pad_node->set_scope(origin_node->scope());
  auto padded_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
  padded_shape[padded_shape.size() - 1] = SizeToLong(pad_dim_size);
  const auto input_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0);
  if (!IsDynamic(padded_shape)) {
    common::AnfAlgo::SetOutputInferTypeAndShape({input_dtype}, {padded_shape}, pad_node.get());
  } else {
    // Dynamic shape: the min/max shapes get the same padded last dimension.
    auto min_shape = common::AnfAlgo::GetInputMinShape(origin_node, 0);
    auto max_shape = common::AnfAlgo::GetInputMaxShape(origin_node, 0);
    if (!min_shape.empty() && !max_shape.empty()) {
      min_shape[padded_shape.size() - 1] = SizeToLong(pad_dim_size);
      max_shape[padded_shape.size() - 1] = SizeToLong(pad_dim_size);
    }
    BaseShapePtr detail_shape = std::make_shared<abstract::Shape>(padded_shape, min_shape, max_shape);
    common::AnfAlgo::SetOutputTypeAndDetailShape({input_dtype}, {detail_shape}, pad_node.get());
  }
  common::AnfAlgo::SetNodeAttr(kAttrPadDimSize, MakeValue(SizeToLong(pad_dim_size)), pad_node);
  return pad_node;
}
// Builds an UnsortedSegmentSumD node that consumes the padded input and the
// original segment_ids; its output shape equals the origin node's output shape
// with the last dim widened to pad_dim_size, and num_segments is set from dim 0.
CNodePtr UnsortedSegmentSumDFission::CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node,
                                                              const CNodePtr &padding,
                                                              const size_t &pad_dim_size) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(padding);
  std::vector<AnfNodePtr> unsorted_segment_sum8_inputs = {
    NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSumD->name())), padding,
    origin_node->input(kIndex2)};
  auto unsorted_segment_sum = NewCNode(unsorted_segment_sum8_inputs, graph);
  MS_EXCEPTION_IF_NULL(unsorted_segment_sum);
  unsorted_segment_sum->set_scope(origin_node->scope());
  auto shape = common::AnfAlgo::GetOutputInferShape(origin_node, 0);
  shape[shape.size() - 1] = SizeToLong(pad_dim_size);
  if (IsDynamic(shape)) {
    auto min_shape = common::AnfAlgo::GetOutputMinShape(origin_node, 0);
    // Bugfix: the max shape must come from the node's OUTPUT, matching the min
    // shape above and the output shape being set (was GetInputMaxShape).
    auto max_shape = common::AnfAlgo::GetOutputMaxShape(origin_node, 0);
    if (!min_shape.empty() && !max_shape.empty()) {
      min_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
      max_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
    }
    BaseShapePtr base_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
    common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetOutputInferDataType(origin_node, 0)},
                                                 {base_shape}, unsorted_segment_sum.get());
  } else {
    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(origin_node, 0)}, {shape},
                                                unsorted_segment_sum.get());
  }
  common::AnfAlgo::SetNodeAttr(kAttrNumSegments, MakeValue(shape[0]), unsorted_segment_sum);
  return unsorted_segment_sum;
}
// Builds a Slice node that trims the padded UnsortedSegmentSumD output back to
// the original node's output shape (all slice offsets are zero).
CNodePtr UnsortedSegmentSumDFission::CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
                                                 const CNodePtr &unsorted_segment_sum8) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(unsort_segment_sum);
  MS_EXCEPTION_IF_NULL(unsorted_segment_sum8);
  std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)),
                                          unsorted_segment_sum8};
  auto slice_node = NewCNode(slice_inputs, graph);
  MS_EXCEPTION_IF_NULL(slice_node);
  // The slice takes over the original node's scope and abstract so downstream
  // users observe the pre-fission output contract.
  slice_node->set_scope(unsort_segment_sum->scope());
  slice_node->set_abstract(unsort_segment_sum->abstract());
  const auto orig_out_shape = common::AnfAlgo::GetOutputInferShape(unsort_segment_sum, 0);
  const std::vector<size_t> zero_offsets(orig_out_shape.size(), 0);
  common::AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Long(zero_offsets)), slice_node);
  common::AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(orig_out_shape), slice_node);
  return slice_node;
}
// Matches any UnsortedSegmentSumD node regardless of its inputs.
const BaseRef UnsortedSegmentSumDFission::DefinePattern() const {
  VarPtr seq_inputs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimUnsortedSegmentSumD, seq_inputs});
}
// Rewrites a qualifying UnsortedSegmentSumD (trailing dim 1 on input0) into
// Padding -> UnsortedSegmentSumD -> Slice, padding the last dim to 8 (fp32) or
// 16 (fp16). Returns nullptr when the node does not qualify.
const AnfNodePtr UnsortedSegmentSumDFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                     const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  auto origin_node = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin_node);
  if (!CheckInputs(origin_node)) {
    return nullptr;
  }
  constexpr auto PADSIZE32 = 8;
  constexpr auto PADSIZE16 = 16;
  const auto input_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0);
  size_t pad_dim_size = 0;
  switch (input_dtype) {
    case kNumberTypeFloat32:
      pad_dim_size = PADSIZE32;
      break;
    case kNumberTypeFloat16:
      pad_dim_size = PADSIZE16;
      break;
    default:
      MS_LOG(DEBUG) << "UnsortedSegmentSum data type not in (float32, float16), no need change";
      return nullptr;
  }
  auto padding = CreatePadding(graph, origin_node, pad_dim_size);
  auto padded_sum = CreateUnsortedSegmentSum(graph, origin_node, padding, pad_dim_size);
  return CreateSlice(graph, origin_node, padded_sum);
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_D_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_D_FISSION_H_
#include <vector>
#include <memory>
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/helper.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
namespace mindspore {
namespace opt {
// Ascend fission pass: rewrites an UnsortedSegmentSumD whose first input has a
// trailing dim of 1 into Padding -> UnsortedSegmentSumD -> Slice so the last
// dimension is padded to an aligned size (8 for float32, 16 for float16).
class UnsortedSegmentSumDFission : public PatternProcessPass {
public:
explicit UnsortedSegmentSumDFission(bool multigraph = true)
: PatternProcessPass("unsorted_segment_sum_d_fission", multigraph) {}
~UnsortedSegmentSumDFission() override = default;
// Matches any UnsortedSegmentSumD node.
const BaseRef DefinePattern() const override;
// Performs the rewrite; returns nullptr when the node does not qualify (wrong
// input count, unknown shapes, trailing dim != 1, or unsupported dtype).
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
private:
// Builds the Padding node that widens the last dim of input0 to pad_dim_size.
CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) const;
// Builds the padded UnsortedSegmentSumD consuming the Padding output.
CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &padding,
const size_t &pad_dim_size) const;
// Builds the Slice node that trims the padded result back to the original shape.
CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
const CNodePtr &unsorted_segment_sum8) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_D_FISSION_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,124 +25,108 @@
namespace mindspore {
namespace opt {
namespace {
bool CheckInputs(const CNodePtr &origin_node) {
MS_EXCEPTION_IF_NULL(origin_node);
if (common::AnfAlgo::GetInputTensorNum(origin_node) != kUnsortedSegmentSumInputTensorNum) {
MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputTensorNum
<< ". CNode= " << origin_node->DebugString();
return false;
}
auto x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
auto y_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
if (x_shape.empty() || y_shape.empty()) {
return false;
}
if (x_shape[x_shape.size() - 1] != 1) {
MS_LOG(DEBUG) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is "
<< x_shape[x_shape.size() - 1];
return false;
}
return x_shape.size() > y_shape.size();
}
constexpr size_t kUnsortedSegmentSumInputNum = 3;
} // namespace
CNodePtr UnsortSegmentSumFission::CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node,
CNodePtr UnsortedSegmentSumFission::CreateConcatD(const FuncGraphPtr &graph, const CNodePtr &sum,
const size_t &pad_dim_size) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(origin_node);
std::vector<AnfNodePtr> padding_inputs = {NewValueNode(std::make_shared<Primitive>(kPaddingOpName)),
origin_node->input(kIndex1)};
auto padding = NewCNode(padding_inputs, graph);
MS_EXCEPTION_IF_NULL(padding);
padding->set_scope(origin_node->scope());
auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
MS_EXCEPTION_IF_NULL(sum);
std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
auto x_input = sum->input(kIndex1);
for (size_t i = 0; i < pad_dim_size; ++i) {
concat_inputs.push_back(x_input);
}
auto concat = NewCNode(concat_inputs, graph);
MS_EXCEPTION_IF_NULL(concat);
concat->set_scope(sum->scope());
auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sum, 0);
shape[shape.size() - 1] = SizeToLong(pad_dim_size);
if (IsDynamic(shape)) {
auto min_shape = common::AnfAlgo::GetInputMinShape(origin_node, 0);
auto max_shape = common::AnfAlgo::GetInputMaxShape(origin_node, 0);
auto min_shape = common::AnfAlgo::GetInputMinShape(sum, 0);
auto max_shape = common::AnfAlgo::GetInputMaxShape(sum, 0);
if (!min_shape.empty() && !max_shape.empty()) {
min_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
max_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
}
BaseShapePtr base_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)},
{base_shape}, padding.get());
common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(sum, 0)},
{base_shape}, concat.get());
} else {
common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)},
{shape}, padding.get());
common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(sum, 0)}, {shape},
concat.get());
}
common::AnfAlgo::SetNodeAttr(kAttrPadDimSize, MakeValue(SizeToLong(pad_dim_size)), padding);
return padding;
common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(shape.size() - 1)), concat);
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(pad_dim_size)}), concat);
return concat;
}
CNodePtr UnsortSegmentSumFission::CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node,
const CNodePtr &padding, const size_t &pad_dim_size) const {
CNodePtr UnsortedSegmentSumFission::CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &orig_sum,
const CNodePtr &concat, const size_t &pad_dim_size) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(origin_node);
MS_EXCEPTION_IF_NULL(padding);
std::vector<AnfNodePtr> unsorted_segment_sum8_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), padding,
origin_node->input(kIndex2)};
auto unsorted_segment_sum = NewCNode(unsorted_segment_sum8_inputs, graph);
MS_EXCEPTION_IF_NULL(unsorted_segment_sum);
unsorted_segment_sum->set_scope(origin_node->scope());
auto shape = common::AnfAlgo::GetOutputInferShape(origin_node, 0);
MS_EXCEPTION_IF_NULL(orig_sum);
MS_EXCEPTION_IF_NULL(concat);
std::vector<AnfNodePtr> new_sum_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), concat, orig_sum->input(kIndex2),
orig_sum->input(kIndex3)};
auto new_sum = NewCNode(new_sum_inputs, graph);
MS_EXCEPTION_IF_NULL(new_sum);
new_sum->set_scope(orig_sum->scope());
auto shape = common::AnfAlgo::GetOutputInferShape(orig_sum, 0);
shape[shape.size() - 1] = SizeToLong(pad_dim_size);
if (IsDynamic(shape)) {
auto min_shape = common::AnfAlgo::GetOutputMinShape(origin_node, 0);
auto max_shape = common::AnfAlgo::GetInputMaxShape(origin_node, 0);
auto min_shape = common::AnfAlgo::GetOutputMinShape(orig_sum, 0);
auto max_shape = common::AnfAlgo::GetInputMaxShape(orig_sum, 0);
if (!min_shape.empty() && !max_shape.empty()) {
min_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
max_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
}
BaseShapePtr base_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetOutputInferDataType(origin_node, 0)},
{base_shape}, unsorted_segment_sum.get());
common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetOutputInferDataType(orig_sum, 0)}, {base_shape},
new_sum.get());
} else {
common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(origin_node, 0)}, {shape},
unsorted_segment_sum.get());
common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(orig_sum, 0)}, {shape},
new_sum.get());
}
return new_sum;
}
common::AnfAlgo::SetNodeAttr(kAttrNumSegments, MakeValue(shape[0]), unsorted_segment_sum);
return unsorted_segment_sum;
}
CNodePtr UnsortSegmentSumFission::CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
const CNodePtr &unsorted_segment_sum8) const {
CNodePtr UnsortedSegmentSumFission::CreateSlice(const FuncGraphPtr &graph, const CNodePtr &orig_sum,
const CNodePtr &new_sum) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(unsort_segment_sum);
MS_EXCEPTION_IF_NULL(unsorted_segment_sum8);
std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)),
unsorted_segment_sum8};
MS_EXCEPTION_IF_NULL(orig_sum);
MS_EXCEPTION_IF_NULL(new_sum);
auto orig_sum_shape = common::AnfAlgo::GetOutputInferShape(orig_sum, 0);
std::vector<int64_t> offsets(orig_sum_shape.size(), 0);
auto offsets_input = CreateShapeValueNode(graph, offsets, true);
auto size_input = CreateShapeValueNode(graph, orig_sum_shape, true);
std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)), new_sum,
offsets_input, size_input};
auto slice = NewCNode(slice_inputs, graph);
MS_EXCEPTION_IF_NULL(slice);
slice->set_scope(unsort_segment_sum->scope());
slice->set_abstract(unsort_segment_sum->abstract());
auto unsort_segment_sum_shape = common::AnfAlgo::GetOutputInferShape(unsort_segment_sum, 0);
std::vector<size_t> offsets(unsort_segment_sum_shape.size(), 0);
common::AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Long(offsets)), slice);
common::AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(unsort_segment_sum_shape), slice);
slice->set_scope(orig_sum->scope());
slice->set_abstract(orig_sum->abstract());
return slice;
}
const BaseRef UnsortSegmentSumFission::DefinePattern() const {
const BaseRef UnsortedSegmentSumFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
VectorRef pattern({prim::kPrimUnsortedSegmentSum, Xs});
return pattern;
}
const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const AnfNodePtr UnsortedSegmentSumFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto origin_node = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(origin_node);
if (!CheckInputs(origin_node)) {
auto sum = CheckAnfNodeIfCNodeAndInputSize(node, kUnsortedSegmentSumInputNum);
auto x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sum, 0);
if (x_shape.size() <= 1 || x_shape.back() != 1) {
return nullptr;
}
size_t pad_dim_size;
auto input_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0);
auto input_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(sum, 0);
constexpr auto PADSIZE32 = 8;
constexpr auto PADSIZE16 = 16;
if (input_dtype == kNumberTypeFloat32) {
@ -150,13 +134,13 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con
} else if (input_dtype == kNumberTypeFloat16) {
pad_dim_size = PADSIZE16;
} else {
MS_LOG(DEBUG) << "UnsortedSegmentSum data type not in (float32, float16), no need change";
MS_LOG(DEBUG) << "UnsortedSegmentSum data type not in (float32, float16), no need change.";
return nullptr;
}
auto padding = CreatePadding(graph, origin_node, pad_dim_size);
auto unsorted_segment_sum8 = CreateUnsortedSegmentSum(graph, origin_node, padding, pad_dim_size);
return CreateSlice(graph, origin_node, unsorted_segment_sum8);
auto concat = CreateConcatD(graph, sum, pad_dim_size);
auto new_sum = CreateUnsortedSegmentSum(graph, sum, concat, pad_dim_size);
return CreateSlice(graph, sum, new_sum);
}
} // namespace opt
} // namespace mindspore

View File

@ -24,20 +24,19 @@
namespace mindspore {
namespace opt {
class UnsortSegmentSumFission : public PatternProcessPass {
class UnsortedSegmentSumFission : public PatternProcessPass {
public:
explicit UnsortSegmentSumFission(bool multigraph = true)
explicit UnsortedSegmentSumFission(bool multigraph = true)
: PatternProcessPass("unsorted_segment_sum_fission", multigraph) {}
~UnsortSegmentSumFission() override = default;
~UnsortedSegmentSumFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
private:
CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) const;
CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &padding,
CNodePtr CreateConcatD(const FuncGraphPtr &graph, const CNodePtr &sum, const size_t &pad_dim_size) const;
CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &orig_sum, const CNodePtr &concat,
const size_t &pad_dim_size) const;
CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
const CNodePtr &unsorted_segment_sum8) const;
CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &orig_sum, const CNodePtr &new_sum) const;
};
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,103 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/ir_fusion/unsorted_segment_sum_replace.h"
#include <string>
#include <vector>
#include <memory>
#include <set>
#include "utils/hash_set.h"
#include "backend/common/optimizer/const_input_to_attr.h"
#include "kernel/kernel_build_info.h"
#include "include/common/utils/utils.h"
#include "backend/common/session/kernel_graph.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "runtime/device/kernel_info.h"
#include "utils/ms_context.h"
#include "plugin/device/ascend/optimizer/optimizer_factory.h"
namespace mindspore::opt {
namespace {
constexpr auto kNumSegments = "num_segments";
} // namespace
// Matches any UnsortedSegmentSumD node regardless of its inputs.
const BaseRef UnsortedSegmentSumReplace::DefinePattern() const {
  auto sum_d_prim = std::make_shared<Primitive>(kUnsortedSegmentSumDOpName);
  VarPtr seq_inputs = std::make_shared<SeqVar>();
  return VectorRef({sum_d_prim, seq_inputs});
}
// Replaces an UnsortedSegmentSumD node carrying a num_segments attr with the
// dynamic UnsortedSegmentSum op, converting the attr into a tensor input, when
// the TBE AICore implementation supports the converted node. Returns nullptr
// (leaving the node unchanged) when any precondition fails.
const AnfNodePtr UnsortedSegmentSumReplace::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                    const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  // Bugfix: kernel_graph was dereferenced below (NewValueNode) without any null
  // check; the cast yields nullptr when func_graph is not a KernelGraph.
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  if (cnode->inputs().empty()) {
    return nullptr;
  }
  if (!common::AnfAlgo::HasNodeAttr(kNumSegments, cnode)) {
    MS_LOG(INFO) << "Has no num_segments attr.";
    return nullptr;
  }
  // Copy a new node to check supported.
  std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kUnsortedSegmentSumOpName))};
  (void)new_inputs.insert(new_inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
  CNodePtr new_cnode = NewCNode(new_inputs, func_graph);
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_abstract(cnode->abstract());
  new_cnode->set_scope(cnode->scope());
  CheckCNodeInputSize(new_cnode, kUnsortedSegmentSumInputTensorNum);
  // Convert attr num_segments to the tensor input
  auto value = primitive->GetAttr(kNumSegments);
  if (value == nullptr) {
    MS_LOG(INFO) << "Can not get attr[" << kNumSegments << "] num_segments.";
    return nullptr;
  }
  tensor::TensorPtr tensor_ptr = nullptr;
  if (value->isa<tensor::Tensor>()) {
    tensor_ptr = value->cast<tensor::TensorPtr>();
  } else if (value->isa<Scalar>()) {
    tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
  } else if (value->isa<ValueTuple>()) {
    tensor_ptr = opt::CreateTupleTensor(value->cast<ValueTuplePtr>());
  } else {
    MS_LOG(INFO) << "The value of attr[" << kNumSegments << "] should be a tensor or scalar or value tuple.";
    return nullptr;
  }
  if (tensor_ptr == nullptr) {
    MS_LOG(INFO) << "Convert attr[" << kNumSegments << "] to tensor value failed.";
    return nullptr;
  }
  auto value_node = kernel_graph->NewValueNode(tensor_ptr);
  MS_EXCEPTION_IF_NULL(value_node);
  new_inputs.push_back(value_node);
  new_cnode->set_inputs(new_inputs);
  if (!CheckAICoreSupportedAny(new_cnode)) {
    MS_LOG(INFO) << "Replace unsorted_segment_sum_d op to unsorted_segment_sum op failed.";
    return nullptr;
  }
  MS_LOG(INFO) << "Replace unsorted_segment_sum_d op to unsorted_segment_sum op success. use tbe aicore.";
  return new_cnode;
}
} // namespace mindspore::opt

View File

@ -0,0 +1,36 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSSION_UNSORTED_SEGMENT_SUM_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSSION_UNSORTED_SEGMENT_SUM_H_
#include <memory>
#include <string>
#include "backend/common/optimizer/optimizer.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
namespace mindspore {
namespace opt {
// Fusion pass that converts an UnsortedSegmentSumD node back to the dynamic
// UnsortedSegmentSum op by turning its num_segments attr into a tensor input,
// provided the TBE AICore kernel supports the converted node.
class UnsortedSegmentSumReplace : public PatternProcessPass {
public:
explicit UnsortedSegmentSumReplace(bool multigraph = true, const string &name = "unsorted_segment_sum_replace")
: PatternProcessPass(name, multigraph) {}
~UnsortedSegmentSumReplace() override = default;
// Matches any UnsortedSegmentSumD node.
const BaseRef DefinePattern() const override;
// Returns the replacement UnsortedSegmentSum node, or nullptr when the node has
// no num_segments attr, the attr cannot be converted to a tensor, or the AICore
// implementation does not support the converted node.
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSSION_UNSORTED_SEGMENT_SUM_H_

View File

@ -287,6 +287,9 @@ CNodePtr AscendVmOpAdapter::CreateTargetOp(const CNodePtr &origin_op,
target_op->set_primal_attrs(origin_op->primal_attrs());
target_op->set_attrs(origin_op->attrs());
auto is_dynamic = common::AnfAlgo::IsDynamicShape(origin_op);
if (target_primitive->name() == "UnsortedSegmentSum") {
target_primitive->set_name("UnsortedSegmentSumD");
}
MS_LOG(DEBUG) << "Create op " << target_op->fullname_with_scope() << " debug string:" << target_op->DebugString()
<< " from " << origin_op->fullname_with_scope() << " debug string:" << origin_op->DebugString()
<< ", is dynamic shape:" << is_dynamic;

View File

@ -79,7 +79,6 @@ RER_ASCEND_STATIC_CONST_TO_ATTR(kResizeNearestNeighborGradOpName, 1);
RER_ASCEND_STATIC_CONST_TO_ATTR(kScatterNdOpName, 2);
RER_ASCEND_STATIC_CONST_TO_ATTR(kSimpleMeanGradOpName, 1);
RER_ASCEND_STATIC_CONST_TO_ATTR(kSliceGradOpName, 2, 3);
RER_ASCEND_STATIC_CONST_TO_ATTR(kSliceOpName, 1, 2);
RER_ASCEND_STATIC_CONST_TO_ATTR(kSpaceToBatchOpName, 1);
RER_ASCEND_STATIC_CONST_TO_ATTR(kSparseApplyAdagradOpName, 2);
RER_ASCEND_STATIC_CONST_TO_ATTR(kSparseGatherV2OpName, 2);

View File

@ -47,6 +47,7 @@ class MS_CORE_API Named : public Value {
const std::string &name() const { return name_; }
/// \brief Setting name of object.
///
/// \param[in] name The name set for the object.
/// \no return.
void set_name(const std::string &name) { name_ = name; }
/// \brief Check whether two Named objects are the same.

View File

@ -226,6 +226,7 @@ constexpr auto kSparseSparseMinimum = "SparseSparseMinimum";
constexpr auto kBroadcastTo = "BroadcastTo";
constexpr auto kSparseReshape = "SparseReshape";
constexpr auto kUnsortedSegmentSum = "UnsortedSegmentSum";
constexpr auto kUnsortedSegmentSumD = "UnsortedSegmentSumD";
constexpr auto kUnsortedSegmentProd = "UnsortedSegmentProd";
constexpr auto kBincount = "Bincount";
@ -540,6 +541,7 @@ GVAR_DEF(PrimitivePtr, kPrimUnstack, std::make_shared<Primitive>(kUnstack));
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentMax, std::make_shared<Primitive>("UnsortedSegmentMax"));
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentProd, std::make_shared<Primitive>(kUnsortedSegmentProd));
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentSum, std::make_shared<Primitive>(kUnsortedSegmentSum));
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentSumD, std::make_shared<Primitive>(kUnsortedSegmentSumD));
GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentMin, std::make_shared<Primitive>("UnsortedSegmentMin"));
GVAR_DEF(PrimitivePtr, kPrimConcatOffset, std::make_shared<Primitive>("ConcatOffset"));
GVAR_DEF(PrimitivePtr, kPrimConcatOffsetV1, std::make_shared<Primitive>("ConcatOffsetV1"));

View File

@ -291,7 +291,6 @@ from .minimum_grad_ds import _minimum_grad_ds_tbe
from .concat import _concat_tbe
from .concat_ds import _concat_ds_tbe
from .slice import _slice_tbe
from .slice_ds import _slice_ds_tbe
from .sign import _sign_tbe
from .sign_ds import _sign_ds_tbe
from .greater import _greater_tbe

View File

@ -19,25 +19,36 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
slice_op_info = TBERegOp("Slice") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("slice_d.so") \
.binfile_name("slice.so") \
.compute_cost(10) \
.kernel_name("slice_d") \
.kernel_name("slice") \
.partial_flag(True) \
.is_dynamic_format(True) \
.attr("begin", "required", "listInt", "all") \
.attr("size", "required", "listInt", "all") \
.dynamic_compile_static(True) \
.dynamic_shape(True) \
.input(0, "x", False, "required", "all") \
.input(1, "begin", False, "required", "all", "optional") \
.input(2, "size", False, "required", "all", "optional") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I64_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I64_Default, DataType.U8_Default) \
.dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U32_Default) \
.dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
.get_op_info()

View File

@ -1,57 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Slice op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
# TBE op-info registration for the dynamic-shape ("ds") Slice kernel.
# Inputs: x (data), begin, size (index tensors, int32 or int64);
# output y keeps the dtype of x. Kernel binary: slice.so.
slice_ds_op_info = (
    TBERegOp("Slice")
    .fusion_type("OPAQUE")
    .async_flag(False)
    .binfile_name("slice.so")
    .compute_cost(10)
    .kernel_name("slice")
    .partial_flag(True)
    .dynamic_shape(True)
    .input(0, "x", False, "required", "all")
    .input(1, "begin", False, "required", "all")
    .input(2, "size", False, "required", "all")
    .output(0, "y", False, "required", "all")
    # (x, begin, size, y) dtype combinations with int32 index tensors.
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default)
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.U8_Default)
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default)
    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U16_Default)
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default)
    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default)
    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U32_Default)
    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default)
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.F16_Default)
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default)
    # Same combinations with int64 index tensors.
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I64_Default, DataType.I8_Default)
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I64_Default, DataType.U8_Default)
    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default)
    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U16_Default)
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default)
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default)
    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U32_Default)
    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default)
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.F16_Default)
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default)
    .get_op_info()
)
@op_info_register(slice_ds_op_info)
def _slice_ds_tbe():
    """Register the dynamic-shape Slice TBE op info with the framework."""
    # Registration happens via the decorator; the function body is a no-op.
    return None

View File

@ -16,7 +16,7 @@
"""UnsortedSegmentSum op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \
unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSumD") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("unsorted_segment_sum_d.so") \

View File

@ -24,6 +24,7 @@ unsorted_segment_sum_ds_op_info = TBERegOp("UnsortedSegmentSum") \
.kernel_name("unsorted_segment_sum") \
.partial_flag(True) \
.dynamic_shape(True) \
.dynamic_compile_static(True) \
.input(0, "x", False, "required", "all", reshape_type="NC") \
.input(1, "segment_ids", False, "required", "all", "optional") \
.input(2, "num_segments", False, "required", "all", "optional") \

View File

@ -14,10 +14,11 @@
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/ir_fission/unsorted_segment_sum_fission.h"
#include "plugin/device/ascend/optimizer/ir_fission/unsorted_segment_sum_d_fission.h"
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "include/common/debug/anf_ir_dump.h"
#include "plugin/device/ascend/optimizer/mindir/ascend_vm_op_adapter.h"
namespace mindspore {
namespace opt {
@ -41,7 +42,8 @@ TEST_F(TestHWUnsortedSegmentSumFission, test_fission) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::UnsortSegmentSumFission>());
pm->AddPass(std::make_shared<opt::AscendVmOpAdapter>());
pm->AddPass(std::make_shared<opt::UnsortedSegmentSumDFission>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);
@ -61,7 +63,8 @@ TEST_F(TestHWUnsortedSegmentSumFission, test_no_fission) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::UnsortSegmentSumFission>());
pm->AddPass(std::make_shared<opt::AscendVmOpAdapter>());
pm->AddPass(std::make_shared<opt::UnsortedSegmentSumDFission>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);

View File

@ -22,7 +22,7 @@ unsorted_segment_sum = P.UnsortedSegmentSum()
num_segments = 4
padding = Primitive('Padding')
op_slice = Primitive('Slice')
op_unsorted_segment_sum = Primitive('UnsortedSegmentSum')
op_unsorted_segment_sum = Primitive('UnsortedSegmentSumD')
class FnDict: