forked from mindspore-Ecosystem/mindspore
!43231 [MS][OPS] update UnsortedSegmentSum ops npu dynamic compile static and add irfission
Merge pull request !43231 from luoyuan/UnsortedSegmentSum-dynamic-compile-static-and-add-irfission
This commit is contained in: commit 5108f4af77
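For context, UnsortedSegmentSum(x, segment_ids, num_segments) scatters the rows of x into num_segments buckets selected by segment_ids and sums the rows that land in the same bucket. A minimal NumPy sketch of the reference semantics (the helper below is illustrative, not code from this patch):

import numpy as np

def unsorted_segment_sum(x, segment_ids, num_segments):
    # Row j of the output accumulates every x[i] with segment_ids[i] == j;
    # ids outside [0, num_segments) are dropped, and untouched rows stay zero.
    out = np.zeros((num_segments,) + x.shape[1:], dtype=x.dtype)
    for i, seg in enumerate(segment_ids):
        if 0 <= seg < num_segments:
            out[seg] += x[i]
    return out

x = np.array([[1.0], [2.0], [3.0], [4.0]], dtype=np.float32)  # note: last dim is 1
print(unsorted_segment_sum(x, np.array([0, 2, 0, 1]), 4))
# -> [[4.], [4.], [2.], [0.]]

The fission passes in this change rewrite exactly this last-dim-1 pattern into a wider sum followed by a Slice, presumably so the Ascend TBE kernel works on an aligned block (8 fp32 or 16 fp16 values); that motivation is an inference from the pad sizes, not stated in the patch.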
@@ -21,6 +21,7 @@
 #include "utils/log_adapter.h"
 #include "mindspore/core/ops/core_ops.h"
 #include "backend/common/optimizer/helper.h"
+#include "utils/ms_context.h"
 
 namespace mindspore {
 namespace opt {
@@ -190,6 +191,12 @@ CNodePtr ConstInputToAttr(const CNodePtr &cnode, const mindspore::HashSet<size_t
     if (kernel_graph != nullptr) {
       kernel_graph->FrontBackendlMapUpdate(cnode, new_cnode);
     }
+    auto ms_context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(ms_context);
+    auto backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+    if (primitive->name() == "UnsortedSegmentSum" && backend == kAscendDevice) {
+      primitive->set_name("UnsortedSegmentSumD");
+    }
     return new_cnode;
   }
   return cnode;
@@ -415,6 +415,7 @@ constexpr auto kUnsortedSegmentMaxOpName = "UnsortedSegmentMax";
 constexpr auto kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
 constexpr auto kUnsortedSegmentProdOpName = "UnsortedSegmentProd";
 constexpr auto kUnsortedSegmentSumOpName = "UnsortedSegmentSum";
+constexpr auto kUnsortedSegmentSumDOpName = "UnsortedSegmentSumD";
 constexpr auto kUpdateCacheOpName = "UpdateCache";
 constexpr auto kUpdateStateOpName = "UpdateState";
@@ -304,6 +304,12 @@ bool DataConvert::RunOpConvertConstInputToAttr(const FrontendOpRunInfoPtr &op_ru
       }
     }
     (void)op_prim->AddAttr(input_name, v);
+    auto ms_context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(ms_context);
+    auto backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+    if (op_prim->name() == "UnsortedSegmentSum" && backend == kAscendDevice) {
+      op_prim->set_name("UnsortedSegmentSumD");
+    }
     (void)op_run_info->index_with_value.emplace_back(std::make_pair(input_index, v));
     return true;
   }
@@ -31,6 +31,7 @@
 #include "plugin/device/ascend/optimizer/ir_fusion/fused_batch_norm_fusion.h"
 #include "plugin/device/ascend/optimizer/ir_fission/layer_norm_grad_split.h"
 #include "plugin/device/ascend/optimizer/ir_fission/unsorted_segment_sum_fission.h"
+#include "plugin/device/ascend/optimizer/ir_fission/unsorted_segment_sum_d_fission.h"
 #include "plugin/device/ascend/optimizer/ir_fission/gather_v2_ds_fission.h"
 #include "plugin/device/ascend/optimizer/ir_fission/bce_with_logits_loss_fission.h"
 #include "plugin/device/ascend/optimizer/ir_fission/broadcastto_fission.h"
@@ -89,6 +90,7 @@
 #include "plugin/device/ascend/optimizer/ir_fusion/transposed_update_fusion.h"
 #include "plugin/device/ascend/optimizer/ir_fusion/softmax_dropout_do_mask_v3_fusion.h"
 #include "plugin/device/ascend/optimizer/ir_fusion/conv2d_backprop_input_dilation_fusion.h"
+#include "plugin/device/ascend/optimizer/ir_fusion/unsorted_segment_sum_replace.h"
 #include "plugin/device/ascend/optimizer/format_type/insert_trans_op.h"
 #include "plugin/device/ascend/optimizer/format_type/trans_op_format_refine.h"
 #include "plugin/device/ascend/optimizer/format_type/dynamic_rnn_grad_reformat.h"
@@ -215,6 +217,7 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {
 
 void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
   MS_EXCEPTION_IF_NULL(ir_fusion_pm);
+  ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumReplace>());
   ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
   ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
   ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
@@ -260,7 +263,8 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<PackFission>());
   ir_fusion_pm->AddPass(std::make_shared<ConcatFission>());
   ir_fusion_pm->AddPass(std::make_shared<ReduceMinFission>());
-  ir_fusion_pm->AddPass(std::make_shared<UnsortSegmentSumFission>());
+  ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumFission>());
+  ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumDFission>());
   ir_fusion_pm->AddPass(std::make_shared<GatherV2DsFission>());
   ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
   ir_fusion_pm->AddPass(std::make_shared<BroadcasttoFission>());
@@ -418,6 +422,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
 #endif
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
+  ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumReplace>());
   ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
   ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
   ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
@@ -0,0 +1,163 @@
+/**
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "plugin/device/ascend/optimizer/ir_fission/unsorted_segment_sum_d_fission.h"
+#include <memory>
+#include <vector>
+#include <algorithm>
+#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
+#include "ir/primitive.h"
+#include "include/common/utils/utils.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool CheckInputs(const CNodePtr &origin_node) {
+  MS_EXCEPTION_IF_NULL(origin_node);
+  if (common::AnfAlgo::GetInputTensorNum(origin_node) != kUnsortedSegmentSumInputTensorNum) {
+    MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputTensorNum
+                  << ". CNode= " << origin_node->DebugString();
+    return false;
+  }
+  auto x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
+  auto y_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
+  if (x_shape.empty() || y_shape.empty()) {
+    return false;
+  }
+  if (x_shape[x_shape.size() - 1] != 1) {
+    MS_LOG(DEBUG) << "UnsortedSegmentSum does not need fission. The last value of input0's shape is "
+                  << x_shape[x_shape.size() - 1];
+    return false;
+  }
+  return x_shape.size() > y_shape.size();
+}
+}  // namespace
+
+CNodePtr UnsortedSegmentSumDFission::CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node,
+                                                   const size_t &pad_dim_size) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(origin_node);
+  std::vector<AnfNodePtr> padding_inputs = {NewValueNode(std::make_shared<Primitive>(kPaddingOpName)),
+                                            origin_node->input(kIndex1)};
+  auto padding = NewCNode(padding_inputs, graph);
+  MS_EXCEPTION_IF_NULL(padding);
+  padding->set_scope(origin_node->scope());
+  auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
+  shape[shape.size() - 1] = SizeToLong(pad_dim_size);
+  if (IsDynamic(shape)) {
+    auto min_shape = common::AnfAlgo::GetInputMinShape(origin_node, 0);
+    auto max_shape = common::AnfAlgo::GetInputMaxShape(origin_node, 0);
+    if (!min_shape.empty() && !max_shape.empty()) {
+      min_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
+      max_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
+    }
+    BaseShapePtr base_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
+    common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)},
+                                                 {base_shape}, padding.get());
+  } else {
+    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)},
+                                                {shape}, padding.get());
+  }
+  common::AnfAlgo::SetNodeAttr(kAttrPadDimSize, MakeValue(SizeToLong(pad_dim_size)), padding);
+  return padding;
+}
+
+CNodePtr UnsortedSegmentSumDFission::CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node,
+                                                              const CNodePtr &padding,
+                                                              const size_t &pad_dim_size) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(origin_node);
+  MS_EXCEPTION_IF_NULL(padding);
+  std::vector<AnfNodePtr> unsorted_segment_sum8_inputs = {
+    NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSumD->name())), padding,
+    origin_node->input(kIndex2)};
+  auto unsorted_segment_sum = NewCNode(unsorted_segment_sum8_inputs, graph);
+  MS_EXCEPTION_IF_NULL(unsorted_segment_sum);
+  unsorted_segment_sum->set_scope(origin_node->scope());
+  auto shape = common::AnfAlgo::GetOutputInferShape(origin_node, 0);
+  shape[shape.size() - 1] = SizeToLong(pad_dim_size);
+  if (IsDynamic(shape)) {
+    auto min_shape = common::AnfAlgo::GetOutputMinShape(origin_node, 0);
+    auto max_shape = common::AnfAlgo::GetOutputMaxShape(origin_node, 0);
+    if (!min_shape.empty() && !max_shape.empty()) {
+      min_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
+      max_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
+    }
+
+    BaseShapePtr base_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
+    common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetOutputInferDataType(origin_node, 0)},
+                                                 {base_shape}, unsorted_segment_sum.get());
+  } else {
+    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(origin_node, 0)}, {shape},
+                                                unsorted_segment_sum.get());
+  }
+
+  common::AnfAlgo::SetNodeAttr(kAttrNumSegments, MakeValue(shape[0]), unsorted_segment_sum);
+  return unsorted_segment_sum;
+}
+
+CNodePtr UnsortedSegmentSumDFission::CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
+                                                 const CNodePtr &unsorted_segment_sum8) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(unsort_segment_sum);
+  MS_EXCEPTION_IF_NULL(unsorted_segment_sum8);
+  std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)),
+                                          unsorted_segment_sum8};
+  auto slice = NewCNode(slice_inputs, graph);
+  MS_EXCEPTION_IF_NULL(slice);
+  slice->set_scope(unsort_segment_sum->scope());
+  slice->set_abstract(unsort_segment_sum->abstract());
+  auto unsort_segment_sum_shape = common::AnfAlgo::GetOutputInferShape(unsort_segment_sum, 0);
+  std::vector<size_t> offsets(unsort_segment_sum_shape.size(), 0);
+  common::AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Long(offsets)), slice);
+  common::AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(unsort_segment_sum_shape), slice);
+  return slice;
+}
+
+const BaseRef UnsortedSegmentSumDFission::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VectorRef pattern({prim::kPrimUnsortedSegmentSumD, Xs});
+  return pattern;
+}
+
+const AnfNodePtr UnsortedSegmentSumDFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                     const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto origin_node = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(origin_node);
+  if (!CheckInputs(origin_node)) {
+    return nullptr;
+  }
+  size_t pad_dim_size;
+  auto input_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0);
+  constexpr auto PADSIZE32 = 8;
+  constexpr auto PADSIZE16 = 16;
+  if (input_dtype == kNumberTypeFloat32) {
+    pad_dim_size = PADSIZE32;
+  } else if (input_dtype == kNumberTypeFloat16) {
+    pad_dim_size = PADSIZE16;
+  } else {
+    MS_LOG(DEBUG) << "UnsortedSegmentSum data type not in (float32, float16), no need to change.";
+    return nullptr;
+  }
+
+  auto padding = CreatePadding(graph, origin_node, pad_dim_size);
+  auto unsorted_segment_sum8 = CreateUnsortedSegmentSum(graph, origin_node, padding, pad_dim_size);
+  return CreateSlice(graph, origin_node, unsorted_segment_sum8);
+}
+}  // namespace opt
+}  // namespace mindspore
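The rewrite this new pass performs can be checked numerically. A NumPy sketch of the Padding -> UnsortedSegmentSumD -> Slice chain, under the assumption that the Padding op fills the added columns with zeros (zero columns stay zero through the sum, so slicing column 0 back out recovers the original result; the helper is illustrative, not patch code):

import numpy as np

def unsorted_segment_sum(x, ids, n):
    out = np.zeros((n,) + x.shape[1:], dtype=x.dtype)
    for i, s in enumerate(ids):
        if 0 <= s < n:
            out[s] += x[i]
    return out

x = np.random.rand(5, 1).astype(np.float32)  # fp32, last dim 1 -> pad_dim_size 8
ids = np.array([0, 3, 0, 2, 3])
n = 4

padded = np.concatenate([x, np.zeros((5, 7), np.float32)], axis=-1)  # Padding
wide = unsorted_segment_sum(padded, ids, n)                          # UnsortedSegmentSumD
sliced = wide[:, :1]                                                 # Slice begin=(0,0), size=(n,1)
assert np.allclose(sliced, unsorted_segment_sum(x, ids, n))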
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_D_FISSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_D_FISSION_H_
+
+#include <vector>
+#include <memory>
+#include "backend/common/optimizer/optimizer.h"
+#include "backend/common/optimizer/helper.h"
+#include "plugin/device/ascend/optimizer/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class UnsortedSegmentSumDFission : public PatternProcessPass {
+ public:
+  explicit UnsortedSegmentSumDFission(bool multigraph = true)
+      : PatternProcessPass("unsorted_segment_sum_d_fission", multigraph) {}
+  ~UnsortedSegmentSumDFission() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
+
+ private:
+  CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) const;
+  CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &padding,
+                                    const size_t &pad_dim_size) const;
+  CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
+                       const CNodePtr &unsorted_segment_sum8) const;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_UNSORTED_SEGMENT_SUM_D_FISSION_H_
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -25,124 +25,108 @@
 namespace mindspore {
 namespace opt {
 namespace {
-bool CheckInputs(const CNodePtr &origin_node) {
-  MS_EXCEPTION_IF_NULL(origin_node);
-  if (common::AnfAlgo::GetInputTensorNum(origin_node) != kUnsortedSegmentSumInputTensorNum) {
-    MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputTensorNum
-                  << ". CNode= " << origin_node->DebugString();
-    return false;
-  }
-  auto x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
-  auto y_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1);
-  if (x_shape.empty() || y_shape.empty()) {
-    return false;
-  }
-  if (x_shape[x_shape.size() - 1] != 1) {
-    MS_LOG(DEBUG) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is "
-                  << x_shape[x_shape.size() - 1];
-    return false;
-  }
-  return x_shape.size() > y_shape.size();
-}
+constexpr size_t kUnsortedSegmentSumInputNum = 3;
 }  // namespace
 
-CNodePtr UnsortSegmentSumFission::CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node,
-                                                const size_t &pad_dim_size) const {
+CNodePtr UnsortedSegmentSumFission::CreateConcatD(const FuncGraphPtr &graph, const CNodePtr &sum,
+                                                  const size_t &pad_dim_size) const {
   MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(origin_node);
-  std::vector<AnfNodePtr> padding_inputs = {NewValueNode(std::make_shared<Primitive>(kPaddingOpName)),
-                                            origin_node->input(kIndex1)};
-  auto padding = NewCNode(padding_inputs, graph);
-  MS_EXCEPTION_IF_NULL(padding);
-  padding->set_scope(origin_node->scope());
-  auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0);
+  MS_EXCEPTION_IF_NULL(sum);
+  std::vector<AnfNodePtr> concat_inputs = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
+  auto x_input = sum->input(kIndex1);
+  for (size_t i = 0; i < pad_dim_size; ++i) {
+    concat_inputs.push_back(x_input);
+  }
+  auto concat = NewCNode(concat_inputs, graph);
+  MS_EXCEPTION_IF_NULL(concat);
+  concat->set_scope(sum->scope());
+  auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sum, 0);
   shape[shape.size() - 1] = SizeToLong(pad_dim_size);
   if (IsDynamic(shape)) {
-    auto min_shape = common::AnfAlgo::GetInputMinShape(origin_node, 0);
-    auto max_shape = common::AnfAlgo::GetInputMaxShape(origin_node, 0);
+    auto min_shape = common::AnfAlgo::GetInputMinShape(sum, 0);
+    auto max_shape = common::AnfAlgo::GetInputMaxShape(sum, 0);
     if (!min_shape.empty() && !max_shape.empty()) {
       min_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
       max_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
     }
     BaseShapePtr base_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
-    common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)},
-                                                 {base_shape}, padding.get());
+    common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(sum, 0)},
+                                                 {base_shape}, concat.get());
   } else {
-    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0)},
-                                                {shape}, padding.get());
+    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetPrevNodeOutputInferDataType(sum, 0)}, {shape},
+                                                concat.get());
   }
-  common::AnfAlgo::SetNodeAttr(kAttrPadDimSize, MakeValue(SizeToLong(pad_dim_size)), padding);
-  return padding;
+  common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(SizeToLong(shape.size() - 1)), concat);
+  common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(std::vector<int64_t>{SizeToLong(pad_dim_size)}), concat);
+  return concat;
 }
 
-CNodePtr UnsortSegmentSumFission::CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node,
-                                                           const CNodePtr &padding, const size_t &pad_dim_size) const {
+CNodePtr UnsortedSegmentSumFission::CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &orig_sum,
                                                             const CNodePtr &concat, const size_t &pad_dim_size) const {
   MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(origin_node);
-  MS_EXCEPTION_IF_NULL(padding);
-  std::vector<AnfNodePtr> unsorted_segment_sum8_inputs = {
-    NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), padding,
-    origin_node->input(kIndex2)};
-  auto unsorted_segment_sum = NewCNode(unsorted_segment_sum8_inputs, graph);
-  MS_EXCEPTION_IF_NULL(unsorted_segment_sum);
-  unsorted_segment_sum->set_scope(origin_node->scope());
-  auto shape = common::AnfAlgo::GetOutputInferShape(origin_node, 0);
+  MS_EXCEPTION_IF_NULL(orig_sum);
+  MS_EXCEPTION_IF_NULL(concat);
+  std::vector<AnfNodePtr> new_sum_inputs = {
+    NewValueNode(std::make_shared<Primitive>(prim::kPrimUnsortedSegmentSum->name())), concat, orig_sum->input(kIndex2),
+    orig_sum->input(kIndex3)};
+  auto new_sum = NewCNode(new_sum_inputs, graph);
+  MS_EXCEPTION_IF_NULL(new_sum);
+  new_sum->set_scope(orig_sum->scope());
+  auto shape = common::AnfAlgo::GetOutputInferShape(orig_sum, 0);
   shape[shape.size() - 1] = SizeToLong(pad_dim_size);
   if (IsDynamic(shape)) {
-    auto min_shape = common::AnfAlgo::GetOutputMinShape(origin_node, 0);
-    auto max_shape = common::AnfAlgo::GetInputMaxShape(origin_node, 0);
+    auto min_shape = common::AnfAlgo::GetOutputMinShape(orig_sum, 0);
+    auto max_shape = common::AnfAlgo::GetOutputMaxShape(orig_sum, 0);
     if (!min_shape.empty() && !max_shape.empty()) {
       min_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
       max_shape[shape.size() - 1] = SizeToLong(pad_dim_size);
     }
 
     BaseShapePtr base_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
-    common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetOutputInferDataType(origin_node, 0)},
-                                                 {base_shape}, unsorted_segment_sum.get());
+    common::AnfAlgo::SetOutputTypeAndDetailShape({common::AnfAlgo::GetOutputInferDataType(orig_sum, 0)}, {base_shape},
+                                                 new_sum.get());
   } else {
-    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(origin_node, 0)}, {shape},
-                                                unsorted_segment_sum.get());
+    common::AnfAlgo::SetOutputInferTypeAndShape({common::AnfAlgo::GetOutputInferDataType(orig_sum, 0)}, {shape},
+                                                new_sum.get());
   }
 
-  common::AnfAlgo::SetNodeAttr(kAttrNumSegments, MakeValue(shape[0]), unsorted_segment_sum);
-  return unsorted_segment_sum;
+  return new_sum;
 }
 
-CNodePtr UnsortSegmentSumFission::CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
-                                              const CNodePtr &unsorted_segment_sum8) const {
+CNodePtr UnsortedSegmentSumFission::CreateSlice(const FuncGraphPtr &graph, const CNodePtr &orig_sum,
+                                                const CNodePtr &new_sum) const {
   MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(unsort_segment_sum);
-  MS_EXCEPTION_IF_NULL(unsorted_segment_sum8);
-  std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)),
-                                          unsorted_segment_sum8};
+  MS_EXCEPTION_IF_NULL(orig_sum);
+  MS_EXCEPTION_IF_NULL(new_sum);
+  auto orig_sum_shape = common::AnfAlgo::GetOutputInferShape(orig_sum, 0);
+  std::vector<int64_t> offsets(orig_sum_shape.size(), 0);
+  auto offsets_input = CreateShapeValueNode(graph, offsets, true);
+  auto size_input = CreateShapeValueNode(graph, orig_sum_shape, true);
+  std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)), new_sum,
+                                          offsets_input, size_input};
   auto slice = NewCNode(slice_inputs, graph);
   MS_EXCEPTION_IF_NULL(slice);
-  slice->set_scope(unsort_segment_sum->scope());
-  slice->set_abstract(unsort_segment_sum->abstract());
-  auto unsort_segment_sum_shape = common::AnfAlgo::GetOutputInferShape(unsort_segment_sum, 0);
-  std::vector<size_t> offsets(unsort_segment_sum_shape.size(), 0);
-  common::AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(Convert2Long(offsets)), slice);
-  common::AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(unsort_segment_sum_shape), slice);
+  slice->set_scope(orig_sum->scope());
+  slice->set_abstract(orig_sum->abstract());
   return slice;
 }
 
-const BaseRef UnsortSegmentSumFission::DefinePattern() const {
+const BaseRef UnsortedSegmentSumFission::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
   VectorRef pattern({prim::kPrimUnsortedSegmentSum, Xs});
   return pattern;
 }
 
-const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
-                                                  const EquivPtr &) const {
+const AnfNodePtr UnsortedSegmentSumFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                    const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
-  auto origin_node = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(origin_node);
-  if (!CheckInputs(origin_node)) {
+  auto sum = CheckAnfNodeIfCNodeAndInputSize(node, kUnsortedSegmentSumInputNum);
+  auto x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(sum, 0);
+  if (x_shape.size() <= 1 || x_shape.back() != 1) {
     return nullptr;
   }
   size_t pad_dim_size;
-  auto input_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(origin_node, 0);
+  auto input_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(sum, 0);
   constexpr auto PADSIZE32 = 8;
   constexpr auto PADSIZE16 = 16;
   if (input_dtype == kNumberTypeFloat32) {
@@ -150,13 +134,13 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con
   } else if (input_dtype == kNumberTypeFloat16) {
     pad_dim_size = PADSIZE16;
   } else {
-    MS_LOG(DEBUG) << "UnsortedSegmentSum data type not in (float32, float16), no need change";
+    MS_LOG(DEBUG) << "UnsortedSegmentSum data type not in (float32, float16), no need to change.";
    return nullptr;
  }
 
-  auto padding = CreatePadding(graph, origin_node, pad_dim_size);
-  auto unsorted_segment_sum8 = CreateUnsortedSegmentSum(graph, origin_node, padding, pad_dim_size);
-  return CreateSlice(graph, origin_node, unsorted_segment_sum8);
+  auto concat = CreateConcatD(graph, sum, pad_dim_size);
+  auto new_sum = CreateUnsortedSegmentSum(graph, sum, concat, pad_dim_size);
+  return CreateSlice(graph, sum, new_sum);
 }
 }  // namespace opt
 }  // namespace mindspore
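The rewritten pass above replaces the old zero-Padding step with a ConcatD that stacks the input beside itself pad_dim_size times along the last axis. Every column of the widened sum then holds the same values, so the trailing Slice (whose begin and size are now value-node inputs, matching the dynamic Slice kernel) still recovers the original output. A NumPy sketch of the equivalence (helper illustrative, not patch code):

import numpy as np

def unsorted_segment_sum(x, ids, n):
    out = np.zeros((n,) + x.shape[1:], dtype=x.dtype)
    for i, s in enumerate(ids):
        if 0 <= s < n:
            out[s] += x[i]
    return out

x = np.random.rand(5, 1).astype(np.float32)
ids = np.array([1, 0, 1, 3, 0])
n = 4

wide = np.concatenate([x] * 8, axis=-1)        # ConcatD, pad_dim_size = 8 for fp32
wide_sum = unsorted_segment_sum(wide, ids, n)  # each column repeats the true sums
sliced = wide_sum[:, :1]                       # Slice begin=(0,0), size=(n,1)
assert np.allclose(sliced, unsorted_segment_sum(x, ids, n))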
@@ -24,20 +24,19 @@
 
 namespace mindspore {
 namespace opt {
-class UnsortSegmentSumFission : public PatternProcessPass {
+class UnsortedSegmentSumFission : public PatternProcessPass {
  public:
-  explicit UnsortSegmentSumFission(bool multigraph = true)
+  explicit UnsortedSegmentSumFission(bool multigraph = true)
       : PatternProcessPass("unsorted_segment_sum_fission", multigraph) {}
-  ~UnsortSegmentSumFission() override = default;
+  ~UnsortedSegmentSumFission() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
 
  private:
-  CNodePtr CreatePadding(const FuncGraphPtr &graph, const CNodePtr &origin_node, const size_t &pad_dim_size) const;
-  CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &origin_node, const CNodePtr &padding,
+  CNodePtr CreateConcatD(const FuncGraphPtr &graph, const CNodePtr &sum, const size_t &pad_dim_size) const;
+  CNodePtr CreateUnsortedSegmentSum(const FuncGraphPtr &graph, const CNodePtr &orig_sum, const CNodePtr &concat,
                                     const size_t &pad_dim_size) const;
-  CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_sum,
-                       const CNodePtr &unsorted_segment_sum8) const;
+  CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &orig_sum, const CNodePtr &new_sum) const;
 };
 }  // namespace opt
 }  // namespace mindspore
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "plugin/device/ascend/optimizer/ir_fusion/unsorted_segment_sum_replace.h"
+
+#include <string>
+#include <vector>
+#include <memory>
+#include <set>
+#include "utils/hash_set.h"
+#include "backend/common/optimizer/const_input_to_attr.h"
+#include "kernel/kernel_build_info.h"
+#include "include/common/utils/utils.h"
+#include "backend/common/session/kernel_graph.h"
+#include "backend/common/session/anf_runtime_algorithm.h"
+#include "include/common/utils/anfalgo.h"
+#include "runtime/device/kernel_info.h"
+#include "utils/ms_context.h"
+#include "plugin/device/ascend/optimizer/optimizer_factory.h"
+
+namespace mindspore::opt {
+namespace {
+constexpr auto kNumSegments = "num_segments";
+}  // namespace
+
+const BaseRef UnsortedSegmentSumReplace::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  auto prim = std::make_shared<Primitive>(kUnsortedSegmentSumDOpName);
+  return VectorRef({prim, Xs});
+}
+
+const AnfNodePtr UnsortedSegmentSumReplace::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                    const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
+  MS_EXCEPTION_IF_NULL(primitive);
+  common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
+  if (cnode->inputs().size() == 0) {
+    return nullptr;
+  }
+  if (!common::AnfAlgo::HasNodeAttr(kNumSegments, cnode)) {
+    MS_LOG(INFO) << "Has no num_segments attr.";
+    return nullptr;
+  }
+
+  // Copy a new node to check whether it is supported.
+  std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kUnsortedSegmentSumOpName))};
+  (void)new_inputs.insert(new_inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
+  CNodePtr new_cnode = NewCNode(new_inputs, func_graph);
+  MS_EXCEPTION_IF_NULL(new_cnode);
+  new_cnode->set_abstract(cnode->abstract());
+  new_cnode->set_scope(cnode->scope());
+  CheckCNodeInputSize(new_cnode, kUnsortedSegmentSumInputTensorNum);
+  // Convert the attr num_segments to a tensor input.
+  auto value = primitive->GetAttr(kNumSegments);
+  if (value == nullptr) {
+    MS_LOG(INFO) << "Can not get attr[" << kNumSegments << "].";
+    return nullptr;
+  }
+  tensor::TensorPtr tensor_ptr = nullptr;
+  if (value->isa<tensor::Tensor>()) {
+    tensor_ptr = value->cast<tensor::TensorPtr>();
+  } else if (value->isa<Scalar>()) {
+    tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
+  } else if (value->isa<ValueTuple>()) {
+    tensor_ptr = opt::CreateTupleTensor(value->cast<ValueTuplePtr>());
+  } else {
+    MS_LOG(INFO) << "The value of attr[" << kNumSegments << "] should be a tensor, a scalar, or a value tuple.";
+    return nullptr;
+  }
+  if (tensor_ptr == nullptr) {
+    MS_LOG(INFO) << "Convert attr[" << kNumSegments << "] to tensor value failed.";
+    return nullptr;
+  }
+  auto value_node = kernel_graph->NewValueNode(tensor_ptr);
+  MS_EXCEPTION_IF_NULL(value_node);
+  new_inputs.push_back(value_node);
+  new_cnode->set_inputs(new_inputs);
+  if (!CheckAICoreSupportedAny(new_cnode)) {
+    MS_LOG(INFO) << "Replace unsorted_segment_sum_d op with unsorted_segment_sum op failed.";
+    return nullptr;
+  }
+
+  MS_LOG(INFO) << "Replace unsorted_segment_sum_d op with unsorted_segment_sum op succeeded; use TBE AICore.";
+  return new_cnode;
+}
+}  // namespace mindspore::opt
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSSION_UNSORTED_SEGMENT_SUM_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSSION_UNSORTED_SEGMENT_SUM_H_
+
+#include <memory>
+#include <string>
+#include "backend/common/optimizer/optimizer.h"
+#include "plugin/device/ascend/optimizer/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class UnsortedSegmentSumReplace : public PatternProcessPass {
+ public:
+  explicit UnsortedSegmentSumReplace(bool multigraph = true, const string &name = "unsorted_segment_sum_replace")
+      : PatternProcessPass(name, multigraph) {}
+  ~UnsortedSegmentSumReplace() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSSION_UNSORTED_SEGMENT_SUM_H_
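UnsortedSegmentSumReplace goes the other way: when the AICore build supports the plain UnsortedSegmentSum kernel, a matched UnsortedSegmentSumD node has its num_segments attribute converted back into a value-node input and is replaced, which is the form the Python frontend produces. A usage sketch of that frontend form (assumes a working MindSpore install; it runs on whatever device target the build supports):

import numpy as np
from mindspore import Tensor, ops

# num_segments is the third call argument, i.e. an input rather than an attr;
# on Ascend the backend may fold it into a num_segments attribute and rename
# the node to UnsortedSegmentSumD when the value is a constant.
x = Tensor(np.array([[1.0], [2.0], [3.0], [4.0]], np.float32))
segment_ids = Tensor(np.array([0, 2, 0, 1], np.int32))
out = ops.UnsortedSegmentSum()(x, segment_ids, 4)
print(out)  # [[4.], [4.], [2.], [0.]]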
@@ -287,6 +287,9 @@ CNodePtr AscendVmOpAdapter::CreateTargetOp(const CNodePtr &origin_op,
   target_op->set_primal_attrs(origin_op->primal_attrs());
   target_op->set_attrs(origin_op->attrs());
   auto is_dynamic = common::AnfAlgo::IsDynamicShape(origin_op);
+  if (target_primitive->name() == "UnsortedSegmentSum") {
+    target_primitive->set_name("UnsortedSegmentSumD");
+  }
   MS_LOG(DEBUG) << "Create op " << target_op->fullname_with_scope() << " debug string:" << target_op->DebugString()
                 << " from " << origin_op->fullname_with_scope() << " debug string:" << origin_op->DebugString()
                 << ", is dynamic shape:" << is_dynamic;
@@ -79,7 +79,6 @@ RER_ASCEND_STATIC_CONST_TO_ATTR(kResizeNearestNeighborGradOpName, 1);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kScatterNdOpName, 2);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kSimpleMeanGradOpName, 1);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kSliceGradOpName, 2, 3);
-RER_ASCEND_STATIC_CONST_TO_ATTR(kSliceOpName, 1, 2);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kSpaceToBatchOpName, 1);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kSparseApplyAdagradOpName, 2);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kSparseGatherV2OpName, 2);
@@ -47,6 +47,7 @@ class MS_CORE_API Named : public Value {
   const std::string &name() const { return name_; }
   /// \brief Setting name of object.
   ///
+  /// \param[in] name The name set for the object.
   /// \no return.
   void set_name(const std::string &name) { name_ = name; }
   /// \brief Check whether two Named objects are the same.
@@ -226,6 +226,7 @@ constexpr auto kSparseSparseMinimum = "SparseSparseMinimum";
 constexpr auto kBroadcastTo = "BroadcastTo";
 constexpr auto kSparseReshape = "SparseReshape";
 constexpr auto kUnsortedSegmentSum = "UnsortedSegmentSum";
+constexpr auto kUnsortedSegmentSumD = "UnsortedSegmentSumD";
 constexpr auto kUnsortedSegmentProd = "UnsortedSegmentProd";
 constexpr auto kBincount = "Bincount";
@@ -540,6 +541,7 @@ GVAR_DEF(PrimitivePtr, kPrimUnstack, std::make_shared<Primitive>(kUnstack));
 GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentMax, std::make_shared<Primitive>("UnsortedSegmentMax"));
 GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentProd, std::make_shared<Primitive>(kUnsortedSegmentProd));
 GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentSum, std::make_shared<Primitive>(kUnsortedSegmentSum));
+GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentSumD, std::make_shared<Primitive>(kUnsortedSegmentSumD));
 GVAR_DEF(PrimitivePtr, kPrimUnsortedSegmentMin, std::make_shared<Primitive>("UnsortedSegmentMin"));
 GVAR_DEF(PrimitivePtr, kPrimConcatOffset, std::make_shared<Primitive>("ConcatOffset"));
 GVAR_DEF(PrimitivePtr, kPrimConcatOffsetV1, std::make_shared<Primitive>("ConcatOffsetV1"));
@@ -291,7 +291,6 @@ from .minimum_grad_ds import _minimum_grad_ds_tbe
 from .concat import _concat_tbe
 from .concat_ds import _concat_ds_tbe
 from .slice import _slice_tbe
-from .slice_ds import _slice_ds_tbe
 from .sign import _sign_tbe
 from .sign_ds import _sign_ds_tbe
 from .greater import _greater_tbe
@@ -19,25 +19,36 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 slice_op_info = TBERegOp("Slice") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("slice_d.so") \
+    .binfile_name("slice.so") \
     .compute_cost(10) \
-    .kernel_name("slice_d") \
+    .kernel_name("slice") \
     .partial_flag(True) \
-    .is_dynamic_format(True) \
-    .attr("begin", "required", "listInt", "all") \
-    .attr("size", "required", "listInt", "all") \
+    .dynamic_compile_static(True) \
+    .dynamic_shape(True) \
     .input(0, "x", False, "required", "all") \
+    .input(1, "begin", False, "required", "all", "optional") \
+    .input(2, "size", False, "required", "all", "optional") \
     .output(0, "y", False, "required", "all") \
-    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
-    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
-    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
-    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
-    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
-    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
-    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
-    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.U8_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
+    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
+    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default) \
+    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I64_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I64_Default, DataType.U8_Default) \
+    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
+    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U16_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U32_Default) \
+    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default) \
+    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
     .get_op_info()
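With begin and size registered as inputs rather than attributes, the static Slice registration now matches the dynamic kernel's interface (dynamic_compile_static plus dynamic_shape), which is what allows the separate slice_ds registration below to be deleted. The Python-level call is unchanged; a small usage sketch (assumes a working MindSpore install):

import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.arange(12, dtype=np.float32).reshape(3, 4))
# begin and size are supplied at call time; under the updated TBE registration
# they reach the kernel as inputs instead of compile-time attrs.
out = ops.Slice()(x, (0, 1), (2, 2))  # rows 0..1, cols 1..2
print(out)  # [[1. 2.] [5. 6.]]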
@@ -1,57 +0,0 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""Slice op"""
-from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
-
-slice_ds_op_info = TBERegOp("Slice") \
-    .fusion_type("OPAQUE") \
-    .async_flag(False) \
-    .binfile_name("slice.so") \
-    .compute_cost(10) \
-    .kernel_name("slice") \
-    .partial_flag(True) \
-    .dynamic_shape(True) \
-    .input(0, "x", False, "required", "all") \
-    .input(1, "begin", False, "required", "all") \
-    .input(2, "size", False, "required", "all") \
-    .output(0, "y", False, "required", "all") \
-    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default) \
-    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.U8_Default) \
-    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
-    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U16_Default) \
-    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
-    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
-    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U32_Default) \
-    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default) \
-    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \
-    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I64_Default, DataType.I8_Default) \
-    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I64_Default, DataType.U8_Default) \
-    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
-    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U16_Default) \
-    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
-    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
-    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U32_Default) \
-    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default) \
-    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.F32_Default) \
-    .get_op_info()
-
-
-@op_info_register(slice_ds_op_info)
-def _slice_ds_tbe():
-    """Slice TBE register"""
-    return
@@ -16,7 +16,7 @@
 """UnsortedSegmentSum op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
-unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSum") \
+unsorted_segment_sum_op_info = TBERegOp("UnsortedSegmentSumD") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
     .binfile_name("unsorted_segment_sum_d.so") \
@@ -24,6 +24,7 @@ unsorted_segment_sum_ds_op_info = TBERegOp("UnsortedSegmentSum") \
     .kernel_name("unsorted_segment_sum") \
     .partial_flag(True) \
     .dynamic_shape(True) \
+    .dynamic_compile_static(True) \
     .input(0, "x", False, "required", "all", reshape_type="NC") \
     .input(1, "segment_ids", False, "required", "all", "optional") \
     .input(2, "num_segments", False, "required", "all", "optional") \
@@ -14,10 +14,11 @@
 * limitations under the License.
 */
 
-#include "plugin/device/ascend/optimizer/ir_fission/unsorted_segment_sum_fission.h"
+#include "plugin/device/ascend/optimizer/ir_fission/unsorted_segment_sum_d_fission.h"
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
 #include "include/common/debug/anf_ir_dump.h"
+#include "plugin/device/ascend/optimizer/mindir/ascend_vm_op_adapter.h"
 
 namespace mindspore {
 namespace opt {
@@ -41,7 +42,8 @@ TEST_F(TestHWUnsortedSegmentSumFission, test_fission) {
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::UnsortSegmentSumFission>());
+  pm->AddPass(std::make_shared<opt::AscendVmOpAdapter>());
+  pm->AddPass(std::make_shared<opt::UnsortedSegmentSumDFission>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(kg);
@@ -61,7 +63,8 @@ TEST_F(TestHWUnsortedSegmentSumFission, test_no_fission) {
 
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
-  pm->AddPass(std::make_shared<opt::UnsortSegmentSumFission>());
+  pm->AddPass(std::make_shared<opt::AscendVmOpAdapter>());
+  pm->AddPass(std::make_shared<opt::UnsortedSegmentSumDFission>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(kg);
@@ -22,7 +22,7 @@ unsorted_segment_sum = P.UnsortedSegmentSum()
 num_segments = 4
 padding = Primitive('Padding')
 op_slice = Primitive('Slice')
-op_unsorted_segment_sum = Primitive('UnsortedSegmentSum')
+op_unsorted_segment_sum = Primitive('UnsortedSegmentSumD')
 
 
 class FnDict: