!15172 Clean GraphKernel's codes from frontend

From: @dayschan
Reviewed-by: @gaoxiong1,@dylangeng,@gaoxiong1
Signed-off-by: @dylangeng,@dylangeng
This commit is contained in:
mindspore-ci-bot 2021-04-19 09:34:35 +08:00 committed by Gitee
commit 0fd1726e79
44 changed files with 124 additions and 1278 deletions

View File

@ -17,6 +17,7 @@
#include <algorithm>
#include <list>
#include <memory>
#include <vector>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
@ -110,8 +111,6 @@
#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h"
#include "backend/optimizer/ascend/ir_fission/split_fission.h"
#include "backend/optimizer/ascend/ir_fission/splitv_fission.h"
#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
@ -199,19 +198,6 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
}
} // namespace
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>();
MS_EXCEPTION_IF_NULL(optimizer);
auto common_process = std::make_shared<PassManager>("graph_kernel_common_process");
MS_EXCEPTION_IF_NULL(common_process);
common_process->AddPass(std::make_shared<ModifyOpAttrs>());
common_process->AddPass(std::make_shared<RemoveNoUseReshapeOp>());
optimizer->AddPassManager(common_process);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>();

View File

@ -24,7 +24,6 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
} // namespace opt

View File

@ -344,7 +344,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
auto real_input_node = kernel_with_index.first;
if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
if (kernel::IsWeightBoundary(real_input_node)) {
// weight
origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
if (origin_type == kTypeUnknown) {
@ -358,9 +358,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
// In graph kernel, we check parameter,
// the eliminate pass will not eliminate this case, so we just do not insert the no used cast.
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
new_inputs.push_back(cur_input);
} else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
auto cast =
AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
MS_EXCEPTION_IF_NULL(cast);

View File

@ -78,21 +78,12 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt
return nullptr;
}
std::vector<AnfNodePtr> todos = {node};
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
kernel::GetValidKernelNodes(sub_graph, &todos);
}
for (auto &t : todos) {
CNodePtr cnode = t->cast<CNodePtr>();
size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 0; i < in_num; ++i) {
if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
<< cnode->DebugString() << "]";
}
CNodePtr cnode = node->cast<CNodePtr>();
size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 0; i < in_num; ++i) {
if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
<< cnode->DebugString() << "]";
}
}
return nullptr;

View File

@ -30,8 +30,7 @@
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<bool> &need_insert_cast) {
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> make_tuple_inputs;
@ -42,32 +41,23 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
AnfNodePtr replace_node = nullptr;
const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
const auto origin_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
auto idx = NewValueNode(SizeToLong(output_idx));
MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int64Imm>(output_idx);
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get());
if (need_insert_cast[output_idx]) {
const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
TypeId origin_type(kTypeUnknown);
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
}
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
if (origin_type != device_type) {
replace_node = AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape,
infer_type, AnfAlgo::GetOutputReshapeType(getitem, 0));
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
}
} else {
replace_node = getitem;
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, getitem.get());
const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
if (origin_type != device_type) {
replace_node = AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape,
origin_type, AnfAlgo::GetOutputReshapeType(getitem, 0));
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
}
} else {
replace_node = getitem;
@ -81,8 +71,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
return make_tuple;
} // namespace
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<bool> &need_insert_cast) {
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
@ -92,23 +81,14 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
// Single output
if (!cnode->Type()->isa<Tuple>()) {
if (!need_insert_cast[0]) {
return cnode;
}
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
TypeId origin_type(kTypeUnknown);
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
}
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
AnfNodePtr replace_node = cnode;
if (origin_type != device_type) {
replace_node = AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape,
infer_type, AnfAlgo::GetOutputReshapeType(cnode, 0));
origin_type, AnfAlgo::GetOutputReshapeType(cnode, 0));
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
@ -119,69 +99,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
return replace_node;
}
// Multiple output
return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast);
}
AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
// insert cast for ops in graph kernel.
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
auto mng = sub_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<AnfNodePtr> todo;
kernel::GetValidKernelNodes(sub_graph, &todo);
auto outputs = AnfAlgo::GetAllOutput(sub_graph->output(), {prim::kPrimTupleGetItem});
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
for (auto &output : outputs) {
size_t index = 0;
if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
ValuePtr tuple_index_value = GetValueNode(output->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
MS_EXCEPTION_IF_NULL(tuple_index_value);
if (!tuple_index_value->isa<Int64Imm>()) {
MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64";
}
index = tuple_index_value->cast<Int64ImmPtr>()->value();
}
graph_rets.emplace_back(std::pair<AnfNodePtr, size_t>(output, index));
}
for (auto &t : todo) {
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t);
// process input
CNodePtr t_cnode = t->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(t_cnode);
auto t_new_node = InsertCastForInput(sub_graph, t_cnode);
AnfNodePtr t_new_node_1 = nullptr;
std::vector<bool> need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true);
// process output
auto iter = std::find_if(graph_rets.begin(), graph_rets.end(),
[&t](const std::pair<AnfNodePtr, size_t> &ret) { return ret.first == t; });
if (iter != graph_rets.end()) {
auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t);
auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second);
auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin());
if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) {
need_insert_cast[iter->second] = false;
} else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) {
need_insert_cast[iter->second] = false;
}
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
} else {
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
}
if (t_new_node_1 != nullptr && t_new_node_1 != t) {
(void)mng->Replace(t, t_new_node_1);
}
}
// insert cast for graph kernel.
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
// process input
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto new_node = InsertCastForInput(func_graph, cnode);
// process output
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
return InsertCastForMultipleOutput(func_graph, cnode);
}
} // namespace
@ -196,11 +114,6 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
return nullptr;
}
if (AnfAlgo::IsGraphKernel(node)) {
return ProcessGraphKernelOp(func_graph, node);
}
// insert cast for single op.
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
// process input
CNodePtr cnode = node->cast<CNodePtr>();
@ -211,7 +124,7 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
kernel_graph->ReplaceInternalOutput(node, new_node);
}
// process output
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
return InsertCastForOutput(func_graph, new_node);
}
} // namespace opt
} // namespace mindspore

View File

@ -150,9 +150,6 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
return nullptr;
}
auto next_cnode = next_node->cast<CNodePtr>();
if (AnfAlgo::IsGraphKernel(next_node)) {
return nullptr;
}
auto next_op_name = AnfAlgo::GetCNodeName(next_cnode);
if (next_op_name == prim::kPrimSend->name() || next_op_name == kStackPushOpName) {
return nullptr;
@ -224,9 +221,6 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
return nullptr;
}
MS_EXCEPTION_IF_NULL(prior_op);
if (AnfAlgo::IsGraphKernel(prior_op)) {
return nullptr;
}
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name() ||

View File

@ -1,99 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
#include <vector>
#include <memory>
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "base/core_ops.h"
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
auto input_format = AnfAlgo::GetInputFormat(cnode, 0);
if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) {
return nullptr;
}
if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode);
return cnode;
}
AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) {
auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
if (input_shape.size() != 5) {
return nullptr;
}
if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) {
return nullptr;
}
auto multiples = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrMultiples);
if (multiples.size() == 4 && multiples[1] == 1) {
multiples.push_back(1);
AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode);
}
return cnode;
}
AnfNodePtr ModifyAttrs(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode);
if (op_name == prim::kPrimTile->name()) {
return ModifyTileOpAttrs(cnode);
} else if (op_name == prim::kPrimReduceSum->name()) {
// kPrimReduceMean
// kPrimReduceSum
// kPrimReduceAll
// kPrimReduceMax
// kPrimReduceMin
return ModifyReduceOpsAttrs(cnode);
}
return nullptr;
}
} // namespace
const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
return nullptr;
}
MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node);
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(fg);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> todos;
kernel::GetValidKernelNodes(fg, &todos);
for (auto &t : todos) {
auto new_node = ModifyAttrs(t->cast<CNodePtr>());
if (new_node != nullptr && new_node != t) {
(void)manager->Replace(t, new_node);
}
}
return node;
}
} // namespace opt
} // namespace mindspore

View File

@ -1,33 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class ModifyOpAttrs : public PatternProcessPass {
public:
explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {}
~ModifyOpAttrs() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H

View File

@ -1,66 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
#include <vector>
#include <memory>
#include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "base/core_ops.h"
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode);
if (op_name != prim::kPrimReshape->name()) {
return nullptr;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) {
return nullptr;
}
return cnode->input(1);
}
} // namespace
const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
return nullptr;
}
MS_LOG(DEBUG) << "====process op: " << AnfAlgo::GetCNodeName(node);
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(fg);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> todos;
kernel::GetValidKernelNodes(fg, &todos);
for (auto &t : todos) {
auto new_node = RemoveReshapeOp(t->cast<CNodePtr>());
if (new_node != nullptr && new_node != t) {
(void)manager->Replace(t, new_node);
}
}
return node;
}
} // namespace opt
} // namespace mindspore

View File

@ -1,33 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class RemoveNoUseReshapeOp : public PatternProcessPass {
public:
explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {}
~RemoveNoUseReshapeOp() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H

View File

@ -122,9 +122,6 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f
if (node == nullptr || !node->isa<CNode>()) {
return nullptr;
}
if (AnfAlgo::IsGraphKernel(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::vector<CNodePtr> cast_nodes;

View File

@ -596,20 +596,21 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
std::vector<PrimitivePtr> GetFusibleOpList() {
#if ENABLE_D
std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::KPrimTransData};
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimInplaceAssign,
prim::KPrimTransData};
#elif ENABLE_GPU
std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimAssign, prim::kPrimLessEqual, prim::kPrimGreaterEqual, prim::kPrimReduceMax, prim::kPrimReduceMin,
prim::kPrimLess};
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimAssign, prim::kPrimLessEqual, prim::kPrimGreaterEqual, prim::kPrimReduceMax, prim::kPrimReduceMin,
prim::kPrimLess, prim::kPrimInplaceAssign};
#else
std::vector<PrimitivePtr> fusible_basic_ops;
#endif

View File

@ -81,13 +81,9 @@ const AnfNodePtr AddTrainingAttr::Process(const FuncGraphPtr &func_graph, const
if (iter == MarkOp.end()) {
return nullptr;
}
if (AnfAlgo::IsGraphKernel(node)) {
return nullptr;
} else {
auto cnode = node->cast<CNodePtr>();
AddAttrTraining(func_graph, cnode);
return cnode;
}
auto cnode = node->cast<CNodePtr>();
AddAttrTraining(func_graph, cnode);
return cnode;
}
} // namespace opt
} // namespace mindspore

View File

@ -15,7 +15,6 @@
*/
#include "backend/optimizer/pass/convert_attr_to_unify_mindir.h"
#include <vector>
#include <string>
#include "utils/check_convert_utils.h"
@ -29,32 +28,22 @@ const AnfNodePtr ConvertAttrToUnifyMindIR::Process(const FuncGraphPtr &, const A
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
return nullptr;
}
std::vector<AnfNodePtr> todos;
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
kernel::GetValidKernelNodes(sub_graph, &todos);
} else {
todos.push_back(node);
}
for (auto &t : todos) {
CNodePtr cnode = t->cast<CNodePtr>();
auto inputs = cnode->inputs();
AnfNodePtr op = inputs[0];
if (IsValueNode<Primitive>(op)) {
auto prim = GetValueNode<PrimitivePtr>(op);
auto attrs = prim->attrs();
std::string type_name = prim->name();
for (auto attr : attrs) {
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr.second);
if (converted) {
prim->set_attr(attr.first, attr.second);
}
bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(type_name, attr.first, &attr.second);
if (converted_ir_attr) {
prim->set_attr(attr.first, attr.second);
}
CNodePtr cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
AnfNodePtr op = inputs[0];
if (IsValueNode<Primitive>(op)) {
auto prim = GetValueNode<PrimitivePtr>(op);
auto attrs = prim->attrs();
std::string type_name = prim->name();
for (auto attr : attrs) {
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr.second);
if (converted) {
prim->set_attr(attr.first, attr.second);
}
bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(type_name, attr.first, &attr.second);
if (converted_ir_attr) {
prim->set_attr(attr.first, attr.second);
}
}
}

View File

@ -15,7 +15,6 @@
*/
#include "backend/optimizer/pass/convert_const_input_to_attr.h"
#include <vector>
#include <string>
#include <memory>
@ -34,40 +33,31 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
return nullptr;
}
std::vector<AnfNodePtr> todos;
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
kernel::GetValidKernelNodes(sub_graph, &todos);
} else {
todos.push_back(node);
}
for (auto &t : todos) {
CNodePtr cnode = t->cast<CNodePtr>();
ConstInputToAttrInfoRegister reg;
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), &reg)) {
continue;
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
continue;
}
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
auto ms_context = MsContext::GetInstance();
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
continue;
}
}
if (AnfAlgo::IsDynamicShape(cnode) &&
DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
continue;
}
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
CNodePtr cnode = node->cast<CNodePtr>();
ConstInputToAttrInfoRegister reg;
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), &reg)) {
return nullptr;
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
return nullptr;
}
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
auto ms_context = MsContext::GetInstance();
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
return nullptr;
}
}
if (AnfAlgo::IsDynamicShape(cnode) &&
DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
return nullptr;
}
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
return node;
}
} // namespace opt

View File

@ -100,24 +100,6 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
}
return nullptr;
}
AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
auto mng = sub_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<AnfNodePtr> todo;
kernel::GetValidKernelNodes(sub_graph, &todo);
for (auto &t : todo) {
auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast<CNodePtr>());
if (t_new_node != nullptr && t_new_node != t) {
(void)mng->Replace(t, t_new_node);
}
}
return node;
}
} // namespace
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
@ -129,11 +111,8 @@ const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &fun
if (!node->isa<CNode>()) {
return nullptr;
}
if (AnfAlgo::IsGraphKernel(node)) {
return ProcessGraphKernelOp(node);
} else {
return ConstInputToTensorInput(func_graph, node->cast<CNodePtr>());
}
return ConstInputToTensorInput(func_graph, node->cast<CNodePtr>());
}
} // namespace opt
} // namespace mindspore

View File

@ -95,15 +95,6 @@ const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &fu
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
return nullptr;
}
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
std::vector<AnfNodePtr> todos;
kernel::GetValidKernelNodes(sub_graph, &todos);
for (auto &t : todos) {
ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast<CNodePtr>());
}
}
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
return node;
}

View File

@ -170,26 +170,6 @@ const AnfNodePtr EliminateRedundantOp::Process(const FuncGraphPtr &func_graph, c
if (cnode == nullptr || func_graph == nullptr) {
return nullptr;
}
if (AnfAlgo::IsGraphKernel(node)) {
// do eliminate for ops in graph kernel.
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
auto mng = sub_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<AnfNodePtr> todo;
kernel::GetValidKernelNodes(sub_graph, &todo);
for (auto &t : todo) {
CNodePtr t_cnode = t->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(t_cnode);
auto t_new_node = DoEliminate(sub_graph, t_cnode);
if (t_new_node != nullptr && t_new_node != t) {
(void)mng->Replace(t, t_new_node);
}
}
return node;
}
// do eliminate for single op.
return DoEliminate(func_graph, cnode);
}
} // namespace opt

View File

@ -30,20 +30,7 @@ const BaseRef EraseVisitAttr::DefinePattern() const {
}
const AnfNodePtr EraseVisitAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
if (node != nullptr && AnfAlgo::IsRealCNodeKernel(node)) {
if (AnfAlgo::IsGraphKernel(node)) {
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(fg);
std::vector<AnfNodePtr> todos;
kernel::GetValidKernelNodes(fg, &todos);
for (auto &t : todos) {
AnfAlgo::EraseNodeAttr(kAttrVisited, t);
}
}
AnfAlgo::EraseNodeAttr(kAttrVisited, node);
} else {
AnfAlgo::EraseNodeAttr(kAttrVisited, node);
}
AnfAlgo::EraseNodeAttr(kAttrVisited, node);
return nullptr;
}
} // namespace opt

View File

@ -949,7 +949,6 @@ void AscendSession::InitRuntimeResource() {
void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "HardwareOptimize start!";
opt::AscendBackendOptimization(kernel_graph);
opt::AscendGraphKernelCommonProcess(kernel_graph);
GraphKernelOptimize(kernel_graph);
MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->SetExecOrderByDefault();

View File

@ -433,7 +433,9 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
auto cnode = FuncGraph::NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
CreateKernelInfoFromNewParameter(cnode);
if (AnfAlgo::IsGraphKernel(cnode)) {
CreateKernelInfoFromNewParameter(cnode);
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
}
@ -443,9 +445,6 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
}
void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
if (!AnfAlgo::IsGraphKernel(cnode)) {
return;
}
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
MS_EXCEPTION_IF_NULL(func_graph);

View File

@ -45,10 +45,6 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
k_graph_ = std::make_shared<FuncGraph>();
}
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
}
// To keep switch_layer's inputs from being inlined
k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
k_graph_->set_stage(primal_graph->stage());
@ -58,11 +54,6 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
tape_ = std::make_shared<FuncGraph>();
}
tape_->set_stage(primal_graph->stage());
// Add "_Grad" postfix
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad";
tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
}
dout_ = tape_->add_parameter();
}

View File

@ -194,8 +194,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Incorporation
incorporate_getitem_set_ =
MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(),
"incorporate_getitem_from_param", IsCNodeGraphKernel);
incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ =
MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
@ -211,19 +209,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
print_tuple_wrapper_ =
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
// Unused parameter eliminate
unused_parameter_eliminate_ =
MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel);
unused_output_eliminate_ =
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
// tuple parameter graph transform
call_graph_tuple_transform_ =
MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transorm", IsCNode);
// AddN eliminate
addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
// RowTensor Eliminate
row_tensor_eliminate_ = MakeSubstitution(
std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",

View File

@ -112,7 +112,6 @@ class OptimizeIRPassLib {
// Incorporation
SubstitutionPtr incorporate_getitem_set_;
SubstitutionPtr incorporate_getitem_from_param_;
SubstitutionPtr incorporate_call_;
SubstitutionPtr incorporate_call_switch_;
@ -125,16 +124,9 @@ class OptimizeIRPassLib {
// Convert
SubstitutionPtr print_tuple_wrapper_;
// Unused parameter eliminate
SubstitutionPtr unused_parameter_eliminate_;
SubstitutionPtr unused_output_eliminate_;
// tuple parameter graph transform
SubstitutionPtr call_graph_tuple_transform_;
// AddN eliminate
SubstitutionPtr addn_eliminate_;
// RowTensor Eliminate
SubstitutionPtr row_tensor_eliminate_;
@ -214,23 +206,6 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) {
return IsValueNode<FuncGraph>(inp0);
}
// Check if CNode Input 0 is Func Graph of graph kernel.
inline bool IsCNodeGraphKernel(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) {
return false;
}
auto inp0 = node->cast<CNodePtr>()->input(0);
if (IsValueNode<FuncGraph>(inp0)) {
auto fg = GetValueNode<FuncGraphPtr>(inp0);
if (fg == nullptr) {
return false;
}
return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}
return false;
}
// Check if CNode Input 0 is CNode
inline bool IsCNodeDup(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) {

View File

@ -20,7 +20,6 @@
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <utility>
@ -135,25 +134,6 @@ class IncorporateGetitem : public AnfVisitor {
return nullptr;
}
if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
// If graph kernel has muti output, do not split.
// some graph kernel output has EnvInstance node or DeadCode node should split.
auto output = fg_->output();
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
auto output_cnode = output->cast<CNodePtr>();
auto outputs = output_cnode->inputs();
int64_t real_output_cnt = 0;
for (size_t i = 1; i < outputs.size(); ++i) {
if (IsCNode(outputs[i]) || IsValueNode<tensor::Tensor>(outputs[i]) || IsParam(outputs[i])) {
real_output_cnt++;
if (real_output_cnt > 1) {
return nullptr;
}
}
}
}
}
auto new_fg = getitem_transform_(fg_, idx_);
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
return node->func_graph()->NewCNode(args_);
@ -184,171 +164,6 @@ class IncorporateGetitem : public AnfVisitor {
internal::GetitemTransform getitem_transform_;
};
class IncorporateGetitemFromParam : public AnfVisitor {
public:
void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &param, size_t input_idx) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto &node_users = mng->node_users();
if (node_users.find(param) == node_users.end() || node_users[param].empty()) {
args_.push_back(cnode->input(input_idx + 1));
return;
}
for (auto &user : node_users[param]) {
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
// we do not process this case.
args_.push_back(cnode->input(input_idx + 1));
return;
}
}
// update new args.
if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) {
// case 1
replace_parameters_[input_idx] = true;
need_update_ = true;
auto make_tuple_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs();
inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1;
args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end());
} else {
// case 2
auto prev_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
auto prev_fg = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
auto fg_output = prev_fg->output();
if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) {
MS_LOG(ERROR) << "The return of: " << prev_fg->ToString()
<< " should be a make tuple, but got: " << fg_output->DebugString();
return;
}
replace_parameters_[input_idx] = true;
need_update_ = true;
auto make_tuple_cnode = fg_output->cast<CNodePtr>();
inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1;
for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) {
auto new_getitem =
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToLong(output_i))});
auto aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(SizeToLong(output_i)));
new_getitem->input(2)->set_abstract(aptr);
new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract());
args_.push_back(new_getitem);
}
}
}
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (node->func_graph() == nullptr) {
return nullptr;
}
Reset();
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return nullptr;
}
auto &inputs = cnode->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
if (fg == nullptr) {
return nullptr;
}
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
auto parameters = fg->parameters();
if (parameters.size() != inputs.size() - 1) {
return nullptr;
}
replace_parameters_ = std::vector<bool>(parameters.size(), false);
inputs_num_ = std::vector<size_t>(parameters.size(), 1);
auto node_fg = node->func_graph();
for (size_t i = 1; i < inputs.size(); ++i) {
if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) {
Process(node_fg, cnode, parameters[i - 1], i - 1);
} else {
args_.push_back(inputs[i]);
}
}
if (!need_update_) {
return nullptr;
}
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
mng->AddFuncGraph(new_fg);
auto node_users = mng->node_users();
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
std::vector<AnfNodePtr> new_parameters;
size_t curr_input_idx{0};
for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) {
if (!replace_parameters_[param_i]) {
if (parameters[param_i]->abstract() != nullptr) {
new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract());
}
new_parameters.push_back(new_fg_parameters[param_i]);
curr_input_idx++;
continue;
}
// make a new parameter.
for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) {
auto new_param = std::make_shared<Parameter>(new_fg);
new_param->set_abstract(args_.at(curr_input_idx)->abstract());
// update users of new parameter.
for (auto &user : node_users[new_fg_parameters[param_i]]) {
idx_ = -1;
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode<Int64Imm>})(user.first);
if (idx_ == -1) {
MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString()
<< " must be tuple getitem here, but got: " << user.first->DebugString();
return nullptr;
}
if (input_i == LongToSize(idx_)) {
for (auto &sub_user : node_users[user.first]) {
auto sub_user_cnode = sub_user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sub_user_cnode);
sub_user_cnode->set_input(sub_user.second, new_param);
(void)mng->Replace(sub_user.first, sub_user_cnode);
}
}
}
new_parameters.push_back(new_param);
curr_input_idx++;
}
}
mng->SetParameters(new_fg, new_parameters);
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
auto new_call = node_fg->NewCNode(args_);
new_call->set_abstract(node->abstract());
return new_call;
}
void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int64_t>(vnode->value()); }
void Visit(const CNodePtr &cnode) override {}
void Reset() {
replace_parameters_.clear();
args_.clear();
inputs_num_.clear();
need_update_ = false;
idx_ = -1;
}
private:
std::vector<bool> replace_parameters_{};
std::vector<AnfNodePtr> args_{};
std::vector<size_t> inputs_num_{};
bool need_update_{false};
int64_t idx_{-1};
};
// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C}
class IncorporateGetitemSwitch : public AnfVisitor {
public:

View File

@ -103,13 +103,6 @@ class InlinerBase : public AnfVisitor {
return nullptr;
}
// Do not inline GraphKernel to Cell.
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
// If the GraphKernel only contains a return node, we make it inlined.
if (fg->nodes().size() - fg->parameters().size() > 1) {
return nullptr;
}
}
Reset();
// 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}

View File

@ -205,130 +205,6 @@ class AddNZeroFilter : public AnfVisitor {
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
bool has_zero_like_{false};
};
// {PrimAddN, {kPrimMakeTuple, Xs}}
// Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd.
// case0: AddN(inputs)(inputs size < 2) -> error
// case1: AddN(inputs)(all inputs is ValueNode) -> error
// case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor)
// case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input)
// -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
class AddNEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(fg);
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
if (fg->recursive()) {
return nullptr;
}
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
mng->AddFuncGraph(new_fg);
need_update_ = false;
bool changed;
do {
changed = Process(new_fg);
} while (changed);
if (!need_update_) {
return nullptr;
} else {
auto new_sx = inputs;
new_sx[0] = NewValueNode(new_fg);
return node->func_graph()->NewCNode(new_sx);
}
}
bool Process(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto nodes = TopoSort(func_graph->output());
bool changed = false;
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &tuple_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(tuple_input);
auto tuple_input_cnode = tuple_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_input_cnode);
auto &tuple_inputs = tuple_input_cnode->inputs();
if (tuple_inputs.size() < 3) {
// case0: inputs size < 2, error
MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2);
}
int64_t valuenode_num = std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0,
[](int64_t accumulator, const AnfNodePtr &node) {
if (IsValueNode<tensor::Tensor>(node)) {
return accumulator + 1;
} else {
return accumulator;
}
});
if (LongToSize(valuenode_num) == tuple_inputs.size()) {
// case1: all inputs is ValueNode, error
MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2);
}
if (tuple_inputs.size() == 3) {
// case2: inputs size = 2, -> TensorAdd(Tensor, Tensor)
MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2);
ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1],
tuple_inputs[2]};
mng->Replace(node, func_graph->NewCNode(new_xs));
changed = true;
continue;
}
auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(),
[](const AnfNodePtr &node) { return IsValueNode<tensor::Tensor>(node); });
if (first_valuenode == tuple_inputs.end()) {
// no ValueNode input found.
continue;
} else {
// case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
std::vector<AnfNodePtr> make_tuple_new_xs{
NewValueNode(prim::kPrimMakeTuple),
};
std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(),
[&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) {
if (node != *first_valuenode) {
make_tuple_new_xs.push_back(node);
}
});
ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations");
auto new_addn = func_graph->NewCNode(
{func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)});
ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
auto new_add =
func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn});
(void)mng->Replace(node, new_add);
changed = true;
continue;
}
}
need_update_ = need_update_ || changed;
return changed;
}
private:
bool need_update_{false};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -22,7 +22,6 @@
#include <memory>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include <tuple>
#include "frontend/optimizer/irpass.h"
@ -162,146 +161,6 @@ class SpecializeOnGraphArguments : public AnfVisitor {
private:
internal::SpecializeTransform specialize_transform_;
};
// Eliminate unused parameters.
// {G, Xs}
class UnusedParasEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(fg);
std::vector<AnfNodePtr> parameters = fg->parameters();
size_t size = parameters.size();
if (size != inputs.size() - 1) {
return nullptr;
}
std::vector<AnfNodePtr> new_xs;
std::vector<bool> keep_parameters;
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
auto &node_users = mng->node_users();
bool has_unused_para = false;
for (size_t i = 0; i < size; ++i) {
auto iter = node_users.find(parameters[i]);
if (iter != node_users.end() && !iter->second.empty()) {
keep_parameters.push_back(true);
new_xs.push_back(inputs[i + 1]);
continue;
}
keep_parameters.push_back(false);
has_unused_para = true;
}
if (!has_unused_para) {
return nullptr;
}
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
mng->AddFuncGraph(new_fg);
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
std::vector<AnfNodePtr> new_parameters;
for (size_t i = 0; i < size; i++) {
if (keep_parameters[i]) {
if (parameters[i]->abstract() != nullptr) {
new_fg_parameters[i]->set_abstract(parameters[i]->abstract());
}
new_parameters.push_back(new_fg_parameters[i]);
}
}
mng->SetParameters(new_fg, new_parameters);
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
return node->func_graph()->NewCNode(new_xs);
}
};
// Eliminate unused outputs.
// {G, Xs}
class UnusedOutputEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(fg);
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
if (fg->recursive()) {
return nullptr;
}
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
mng->AddFuncGraph(new_fg);
auto new_fg_output = new_fg->output();
if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) {
return nullptr;
}
auto output_cnode = new_fg_output->cast<CNodePtr>();
auto &node_users = mng->node_users();
if (node_users.count(node) == 0 || node_users[node].empty()) {
return nullptr;
}
std::unordered_set<int64_t> used_output_idx;
std::vector<std::pair<AnfNodePtr, int64_t>> all_users;
for (auto &node_user : node_users[node]) {
if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
return nullptr;
}
auto user_cnode = node_user.first->cast<CNodePtr>();
size_t used_idx = GetValue<int64_t>(user_cnode->input(2)->cast<ValueNodePtr>()->value());
used_output_idx.insert(used_idx);
all_users.push_back(std::make_pair(node_user.first, used_idx));
}
if (used_output_idx.size() >= output_cnode->inputs().size() - 1) {
// all output has users.
return nullptr;
}
if (used_output_idx.empty()) {
// we do not process this case.
return nullptr;
} else if (used_output_idx.size() == 1) {
// after eliminate, only one output left.
new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1));
// update users.
for (auto &ret_user : all_users) {
(void)mng->Replace(ret_user.first, node);
}
} else {
// after eliminate, create new multi output.
std::vector<AnfNodePtr> new_output_inputs{output_cnode->input(0)};
std::unordered_map<int64_t, int64_t> new_idx_map;
for (auto idx : used_output_idx) {
new_idx_map[idx] = SizeToLong(new_output_inputs.size() - 1);
new_output_inputs.push_back(output_cnode->input(idx + 1));
}
new_fg->set_output(new_fg->NewCNode(new_output_inputs));
// update users.
for (auto &ret_user : all_users) {
auto ret_user_cnode = ret_user.first->cast<CNodePtr>();
ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second]));
}
}
auto new_sx = inputs;
new_sx[0] = NewValueNode(new_fg);
return node->func_graph()->NewCNode(new_sx);
}
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -39,7 +39,6 @@
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/optimizer/recompute.h"
#include "utils/log_adapter.h"
#include "utils/context/graph_kernel_flags.h"
#include "pipeline/jit/pipeline_split.h"
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h"
@ -296,31 +295,6 @@ OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &i
return map;
}
OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) {
OptPassGroupMap map({
{"renormalize", opt::OptPassConfig::Renormalize()},
{"cse", opt::OptPassConfig(opt::CSEPass(false))},
});
return map;
}
OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig elim_1 = opt::OptPassConfig({
irpass.addn_eliminate_,
irpass.incorporate_getitem_from_param_,
});
opt::OptPassConfig elim_2 = opt::OptPassConfig({
irpass.unused_parameter_eliminate_,
irpass.unused_output_eliminate_,
});
OptPassGroupMap map({
{"elim_1", elim_1},
{"renormalize", opt::OptPassConfig::Renormalize()},
{"elim_2", elim_2},
});
return map;
}
OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) {
return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}});
}
@ -375,10 +349,6 @@ void InitOpt(const ResourcePtr &res) {
Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
g_pass_opts["opt_trans_graph"] =
Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true);
g_pass_opts["opt_graph_kernel_a"] =
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
g_pass_opts["opt_graph_kernel_b"] =
Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false);
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
g_pass_opts["opt_grad_epilogue"] =
@ -386,10 +356,6 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
g_pass_opts["opt_after_recompute"] =
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
}
}
}
} // namespace
@ -424,8 +390,6 @@ bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a");
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
@ -559,8 +523,6 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"opt_after_cconv", OptPassAfterCconvGroup},
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
{"tuple_transform", OptPassTransformGraphGroup},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_cache_embedding", AddCacheEmbeddingPass},
{"add_recomputation", AddRecomputationPass},
{"cse_after_recomputation", OptAfterRecomputeGroup}};

View File

@ -216,9 +216,7 @@ void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const {
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
}
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));

View File

@ -602,12 +602,6 @@ bool GraphPartition::IsCut(const AnfNodePtr &node) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
AnfNodePtr fn = inputs[0];
if (IsValueNode<FuncGraph>(fn)) {
auto fg = GetValueNode<FuncGraphPtr>(fn);
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
return false;
}
}
if (!IsValueNode<Primitive>(fn)) {
return true;
}

View File

@ -510,7 +510,7 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(prim_graph);
FuncGraphSet graphs = prim_graph->manager()->func_graphs();
for (auto g : graphs) {
if (g != graph && g != nullptr && !(g->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
if (g != graph && g != nullptr) {
Compile(g);
}
}
@ -542,7 +542,7 @@ uint32_t GraphCompiler::CompileGraphs(const FuncGraphPtr &func_graph) {
// Compile sub graphs.
FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
for (auto sub_graph : sub_graphs) {
if (sub_graph != func_graph && sub_graph != nullptr && !(sub_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
if (sub_graph != func_graph && sub_graph != nullptr) {
(void)CompileGraph(sub_graph);
}
}

View File

@ -23,6 +23,7 @@ import numpy
from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.common._decorator import deprecated
from mindspore.context import ParallelMode
from .. import context
from .._c_expression import init_pipeline, Cell_, FuncGraph
@ -1213,6 +1214,11 @@ class GraphKernel(Cell):
A `GraphKernel` a composite of basic primitives and can be compiled into a fused kernel automatically when
enable_graph_kernel in context is set to True.
This class is deprecated from version 1.3 and will be removed in a future version, use Cell instead.
GraphKernel is not supported user-defined cells anymore, the `GraphKernel` objects will be treated as
normal `Cell` objects.
Args:
auto_prefix (bool): Recursively generate namespaces. Default: True.
flags (dict) : Set graph flags. Default: None.
@ -1230,10 +1236,9 @@ class GraphKernel(Cell):
... return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
"""
@deprecated("1.3", "Cell", True)
def __init__(self, auto_prefix=True, flags=None):
super(GraphKernel, self).__init__(auto_prefix, flags)
class_name = self.__class__.__name__
self.add_flags(graph_kernel=class_name)
def construct(self):
raise NotImplementedError

View File

@ -16,7 +16,6 @@
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import _selected_ops
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
@ -87,7 +86,7 @@ class Softmax(Cell):
def __init__(self, axis=-1):
super(Softmax, self).__init__()
self.softmax = _selected_ops.Softmax(axis)
self.softmax = P.Softmax(axis)
def construct(self, x):
return self.softmax(x)
@ -137,7 +136,7 @@ class LogSoftmax(Cell):
def __init__(self, axis=-1):
super(LogSoftmax, self).__init__()
self.log_softmax = _selected_ops.LogSoftmax(axis)
self.log_softmax = P.LogSoftmax(axis)
def construct(self, x):
return self.log_softmax(x)
@ -368,7 +367,7 @@ class Tanh(Cell):
def __init__(self):
super(Tanh, self).__init__()
self.tanh = _selected_ops.Tanh()
self.tanh = P.Tanh()
def construct(self, x):
return self.tanh(x)
@ -415,7 +414,7 @@ class GELU(Cell):
def __init__(self):
super(GELU, self).__init__()
self.gelu = _selected_ops.GeLU()
self.gelu = P.GeLU()
def construct(self, x):
return self.gelu(x)
@ -458,7 +457,7 @@ class FastGelu(Cell):
def __init__(self):
super(FastGelu, self).__init__()
self.fast_gelu = _selected_ops.FastGeLU()
self.fast_gelu = P.FastGeLU()
def construct(self, x):
return self.fast_gelu(x)

View File

@ -30,7 +30,6 @@ from mindspore._checkparam import Validator as validator
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
from mindspore.ops import _selected_ops
from mindspore.common import dtype as mstype
from ..cell import Cell
@ -837,9 +836,9 @@ class LayerNorm(Cell):
gamma_init, normalized_shape), name="gamma")
self.beta = Parameter(initializer(
beta_init, normalized_shape), name="beta")
self.layer_norm = _selected_ops.LayerNorm(begin_norm_axis=self.begin_norm_axis,
begin_params_axis=self.begin_params_axis,
epsilon=self.epsilon)
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis,
begin_params_axis=self.begin_params_axis,
epsilon=self.epsilon)
def construct(self, input_x):
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)

View File

@ -21,7 +21,6 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import nn
from mindspore.ops.primitive import constexpr
from mindspore.ops import _selected_ops
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore._checkparam import Validator as validator
@ -48,7 +47,7 @@ class _Loss(Cell):
if reduction == 'none':
self.reduce = False
self.reduce_mean = _selected_ops.ReduceMean()
self.reduce_mean = P.ReduceMean()
self.reduce_sum = P.ReduceSum()
self.mul = P.Mul()
self.cast = P.Cast()
@ -381,7 +380,7 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
self.sparse = validator.check_bool(sparse, "sparse")
self.reduction = reduction
self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits()
self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0., mstype.float32)

View File

@ -14,7 +14,6 @@
# ============================================================================
"""momentum"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops import _selected_ops
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
@ -157,7 +156,7 @@ class Momentum(Optimizer):
self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov)
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
def construct(self, gradients):
params = self.params

View File

@ -19,7 +19,6 @@ from functools import reduce
import numpy as np
import mindspore as ms
from mindspore import nn
from mindspore.ops import _selected_grad_ops as SG
from .. import functional as F
from .. import operations as P
from ..operations import _grad_ops as G
@ -590,7 +589,7 @@ def get_bprop_expm1(self):
@bprop_getters.register(P.Minimum)
def get_bprop_minimum(self):
"""Grad definition for `Minimum` operation."""
input_grad = SG.MinimumGrad()
input_grad = G.MinimumGrad()
def bprop(x, y, out, dout):
dx, dy = input_grad(x, y, dout)
@ -602,7 +601,7 @@ def get_bprop_minimum(self):
@bprop_getters.register(P.Maximum)
def get_bprop_maximum(self):
"""Grad definition for `Maximum` operation."""
input_grad = SG.MaximumGrad()
input_grad = G.MaximumGrad()
def bprop(x, y, out, dout):
dx, dy = input_grad(x, y, dout)
@ -1107,7 +1106,7 @@ def get_bprop_cosh(self):
@bprop_getters.register(P.Abs)
def get_bprop_abs(self):
"""Grad definition for `Abs` operation."""
abs_grad = SG.AbsGrad()
abs_grad = G.AbsGrad()
def bprop(x, out, dout):
dx = abs_grad(x, dout)

View File

@ -15,7 +15,6 @@
"""Define the grad rules of neural network related operations."""
import os
from mindspore.ops import _selected_grad_ops as SG
from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor
from mindspore.ops.operations import nn_ops as nps
@ -34,7 +33,7 @@ env_force_bprop_seq = os.getenv("ENV_FORCE_BPROP_SEQ")
@bprop_getters.register(P.BiasAdd)
def get_bprop_bias_add(self):
"""Grad definition for `BiasAdd` operation."""
bias_grad = SG.BiasAddGrad(self.data_format)
bias_grad = G.BiasAddGrad(self.data_format)
def bprop(x, w, out, dout):
return dout, bias_grad(dout)
@ -341,7 +340,7 @@ def get_bprop_dropout_do_mask(self):
def get_bprop_mish(self):
"""Grad definition for `Mish` operation."""
tanh = P.Tanh()
tanh_grad = SG.TanhGrad()
tanh_grad = G.TanhGrad()
softplus = P.Softplus()
softplus_grad = G.SoftplusGrad()
@ -580,7 +579,7 @@ def get_bprop_softsign(self):
@bprop_getters.register(P.Tanh)
def get_bprop_tanh(self):
"""Grad definition for `Tanh` operation."""
tanh_grad = SG.TanhGrad()
tanh_grad = G.TanhGrad()
def bprop(x, out, dout):
dx = tanh_grad(out, dout)

View File

@ -1,50 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" resolved grad ops """
from mindspore.ops.op_selector import new_ops_selector
op_selector = new_ops_selector(
"mindspore.ops.operations._grad_ops", "mindspore.nn._graph_kernels")
@op_selector
class MaximumGrad:
def __call__(self, *args):
pass
@op_selector
class MinimumGrad:
def __call__(self, *args):
pass
@op_selector
class AbsGrad:
def __call__(self, *args):
pass
@op_selector
class BiasAddGrad:
def __call__(self, *args):
pass
@op_selector
class TanhGrad:
def __call__(self, *args):
pass

View File

@ -1,120 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" resolve ops """
from mindspore.ops.op_selector import new_ops_selector
op_selector = new_ops_selector(
"mindspore.ops.operations", "mindspore.nn._graph_kernels")
opt_selector = new_ops_selector(
"mindspore.nn.optim", "mindspore.nn._graph_kernels")
nn_selector = new_ops_selector(
"mindspore.nn", "mindspore.nn._graph_kernels")
@nn_selector
class BatchNorm2d:
def __call__(self, *args):
pass
@op_selector
class ReLU:
def __call__(self, *args):
pass
@op_selector
class ReduceMean:
def __call__(self, *args):
pass
@op_selector
class BiasAdd:
def __call__(self, *args):
pass
@op_selector
class ApplyMomentum:
def __call__(self, *args):
pass
@op_selector
class SoftmaxCrossEntropyWithLogits:
def __call__(self, *args):
pass
@op_selector
class LogSoftmax:
def __call__(self, *args):
pass
@op_selector
class Tanh:
def __call__(self, *args):
pass
@op_selector
class GeLU:
def __call__(self, *args):
pass
@op_selector
class FastGeLU:
def __call__(self, *args):
pass
@op_selector
class LayerNorm:
def __call__(self, *args):
pass
@op_selector
class Softmax:
def __call__(self, *args):
pass
@op_selector
class LambUpdateWithLR:
def __call__(self, *args):
pass
@op_selector
class LambNextMV:
def __call__(self, *args):
pass
@op_selector
class LambApplyOptimizerAssign:
def __call__(self, *args):
pass
@op_selector
class LambApplyWeightAssign:
def __call__(self, *args):
pass

View File

@ -20,6 +20,7 @@ which can be used to control the switch of op type: GraphKernel or Primitive.
import importlib
import inspect
from mindspore import context
from mindspore.common._decorator import deprecated
class _OpSelector:
@ -70,6 +71,7 @@ class _OpSelector:
return op(*args, **kwargs)
@deprecated("1.3", "basic Primitive", False)
def new_ops_selector(primitive_pkg, graph_kernel_pkg):
"""
A factory method to return an op selector
@ -83,6 +85,8 @@ def new_ops_selector(primitive_pkg, graph_kernel_pkg):
The order of the highest priority to lowest priority is (1), (2), (3)
If the GraphKernel switch is off, then op_type will always be PRIMITIVE.
The user-defined GraphKernel Cell is deprecated, this interface will be removed in a future version.
Args:
primitive_pkg (str): primitive op's package name
graph_kernel_pkg (str): graph kernel op's package name

View File

@ -16,6 +16,7 @@
"""Other operators."""
import functools
from mindspore.common import monad
from mindspore.common._decorator import deprecated
from .. import signature as sig
from ..._checkparam import Validator as validator, Rel
from ...common import dtype as mstype
@ -84,6 +85,8 @@ class InplaceAssign(PrimitiveWithInfer):
Inplace assign `Parameter` with a value.
This primitive can only use in graph kernel.
InplaceAssign is deprecated from version 1.3 and will be removed in a future version, use Assign instead.
Inputs:
- **variable** (Parameter) - The `Parameter`.
- **value** (Tensor) - The value to be assigned.
@ -110,7 +113,8 @@ class InplaceAssign(PrimitiveWithInfer):
>>> net = Net()
>>> output = net(x)
>>> print(output)
"""
"""
@deprecated("1.3", "Assign", False)
@ prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])

View File

@ -26,7 +26,6 @@ from mindspore.nn.loss.loss import _Loss
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import _selected_ops
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model
from mindspore.context import ParallelMode
@ -76,7 +75,7 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
self.sparse = sparse
self.reduction = reduction
self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits()
self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0., mstype.float32)