Clean GraphKernel code out of the frontend

1. Mark the GraphKernel class as deprecated and treat it as an ordinary Cell.
2. Mark the InplaceAssign class as deprecated; Assign is suggested instead.
3. Mark op_selector as deprecated; remove _selected_ops and _selected_grad_ops and replace them with the real operations.
4. Remove the two GraphKernel passes from the frontend.
5. Remove GraphKernel code from the other modules.
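
A rough migration sketch for the deprecations in items 1 and 2. This is illustrative only and not part of this commit; the module paths (mindspore.nn.GraphKernel, mindspore.ops.Assign) and the cell below are assumptions and may differ between MindSpore versions.

import mindspore.nn as nn
import mindspore.ops as ops

# Before: a fused construct derived from the deprecated GraphKernel class,
# updating a parameter with the deprecated InplaceAssign.
# class FusedMulAssign(nn.GraphKernel):
#     ...

# After: an ordinary Cell; Assign replaces InplaceAssign for the in-place update.
class FusedMulAssign(nn.Cell):
    def __init__(self):
        super().__init__()
        self.assign = ops.Assign()  # suggested replacement for InplaceAssign
        self.mul = ops.Mul()

    def construct(self, param, x, y):
        out = self.mul(x, y)
        self.assign(param, out)  # writes the result back into the Parameter
        return out

With the frontend passes gone, whether such a cell is fused is decided by the backend graph kernel optimizer (for example via the enable_graph_kernel context option), not by the removed frontend passes.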
dayschan 2021-04-06 16:19:43 +08:00
parent c3f3fcab3f
commit 771e3f61f3
44 changed files with 124 additions and 1278 deletions

View File

@ -17,6 +17,7 @@
#include <algorithm> #include <algorithm>
#include <list> #include <list>
#include <memory> #include <memory>
#include <vector>
#include <string> #include <string>
#include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h" #include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
@ -110,8 +111,6 @@
#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h" #include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h"
#include "backend/optimizer/ascend/ir_fission/split_fission.h" #include "backend/optimizer/ascend/ir_fission/split_fission.h"
#include "backend/optimizer/ascend/ir_fission/splitv_fission.h" #include "backend/optimizer/ascend/ir_fission/splitv_fission.h"
#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" #include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
#include "backend/optimizer/ascend/format_type/remove_internal_output.h" #include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#include "backend/optimizer/ascend/ir_fission/concat_fission.h" #include "backend/optimizer/ascend/ir_fission/concat_fission.h"
@ -199,19 +198,6 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>()); ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
} }
} // namespace } // namespace
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>();
MS_EXCEPTION_IF_NULL(optimizer);
auto common_process = std::make_shared<PassManager>("graph_kernel_common_process");
MS_EXCEPTION_IF_NULL(common_process);
common_process->AddPass(std::make_shared<ModifyOpAttrs>());
common_process->AddPass(std::make_shared<RemoveNoUseReshapeOp>());
optimizer->AddPassManager(common_process);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) { void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>(); auto optimizer = std::make_shared<GraphOptimizer>();

View File

@ -24,7 +24,6 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph); void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph); void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
} // namespace opt } // namespace opt

View File

@@ -344,7 +344,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
     auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
     auto real_input_node = kernel_with_index.first;
-    if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
+    if (kernel::IsWeightBoundary(real_input_node)) {
       // weight
       origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
       if (origin_type == kTypeUnknown) {
@@ -358,9 +358,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
     const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
     // In graph kernel, we check parameter,
     // the eliminate pass will not eliminate this case, so we just do not insert the no used cast.
-    if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
-      new_inputs.push_back(cur_input);
-    } else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
+    if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
       auto cast =
         AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
       MS_EXCEPTION_IF_NULL(cast);

View File

@@ -78,21 +78,12 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt
     return nullptr;
   }
-  std::vector<AnfNodePtr> todos = {node};
-  if (AnfAlgo::IsGraphKernel(node)) {
-    auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
-    MS_EXCEPTION_IF_NULL(sub_graph);
-    kernel::GetValidKernelNodes(sub_graph, &todos);
-  }
-  for (auto &t : todos) {
-    CNodePtr cnode = t->cast<CNodePtr>();
-    size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
-    for (size_t i = 0; i < in_num; ++i) {
-      if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
-        MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
-                          << cnode->DebugString() << "]";
-      }
+  CNodePtr cnode = node->cast<CNodePtr>();
+  size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
+  for (size_t i = 0; i < in_num; ++i) {
+    if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
+      MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
+                        << cnode->DebugString() << "]";
     }
   }
   return nullptr;

View File

@ -30,8 +30,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
const std::vector<bool> &need_insert_cast) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> make_tuple_inputs; std::vector<AnfNodePtr> make_tuple_inputs;
@ -42,32 +41,23 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
for (size_t output_idx = 0; output_idx < out_num; ++output_idx) { for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
AnfNodePtr replace_node = nullptr; AnfNodePtr replace_node = nullptr;
const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); const auto origin_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
auto idx = NewValueNode(SizeToLong(output_idx)); auto idx = NewValueNode(SizeToLong(output_idx));
MS_EXCEPTION_IF_NULL(idx); MS_EXCEPTION_IF_NULL(idx);
auto imm = std::make_shared<Int64Imm>(output_idx); auto imm = std::make_shared<Int64Imm>(output_idx);
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm)); idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get()); AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, getitem.get());
if (need_insert_cast[output_idx]) { const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx); const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
TypeId origin_type(kTypeUnknown); if (origin_type != device_type) {
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { replace_node = AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape,
origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode); origin_type, AnfAlgo::GetOutputReshapeType(getitem, 0));
} MS_EXCEPTION_IF_NULL(replace_node);
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type; replace_node->set_scope(cnode->scope());
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
if (origin_type != device_type) { if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
replace_node = AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
infer_type, AnfAlgo::GetOutputReshapeType(getitem, 0));
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
}
} else {
replace_node = getitem;
} }
} else { } else {
replace_node = getitem; replace_node = getitem;
@ -81,8 +71,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
return make_tuple; return make_tuple;
} // namespace } // namespace
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
const std::vector<bool> &need_insert_cast) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
@ -92,23 +81,14 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); auto kernel_graph = func_graph->cast<KernelGraphPtr>();
// Single output // Single output
if (!cnode->Type()->isa<Tuple>()) { if (!cnode->Type()->isa<Tuple>()) {
if (!need_insert_cast[0]) {
return cnode;
}
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0); const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0); std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0); const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
TypeId origin_type(kTypeUnknown);
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
}
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0); const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
AnfNodePtr replace_node = cnode; AnfNodePtr replace_node = cnode;
if (origin_type != device_type) { if (origin_type != device_type) {
replace_node = AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, replace_node = AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape,
infer_type, AnfAlgo::GetOutputReshapeType(cnode, 0)); origin_type, AnfAlgo::GetOutputReshapeType(cnode, 0));
MS_EXCEPTION_IF_NULL(replace_node); MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope()); replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
@ -119,69 +99,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
return replace_node; return replace_node;
} }
// Multiple output // Multiple output
return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast); return InsertCastForMultipleOutput(func_graph, cnode);
}
AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
// insert cast for ops in graph kernel.
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
auto mng = sub_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<AnfNodePtr> todo;
kernel::GetValidKernelNodes(sub_graph, &todo);
auto outputs = AnfAlgo::GetAllOutput(sub_graph->output(), {prim::kPrimTupleGetItem});
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
for (auto &output : outputs) {
size_t index = 0;
if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
ValuePtr tuple_index_value = GetValueNode(output->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
MS_EXCEPTION_IF_NULL(tuple_index_value);
if (!tuple_index_value->isa<Int64Imm>()) {
MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64";
}
index = tuple_index_value->cast<Int64ImmPtr>()->value();
}
graph_rets.emplace_back(std::pair<AnfNodePtr, size_t>(output, index));
}
for (auto &t : todo) {
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t);
// process input
CNodePtr t_cnode = t->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(t_cnode);
auto t_new_node = InsertCastForInput(sub_graph, t_cnode);
AnfNodePtr t_new_node_1 = nullptr;
std::vector<bool> need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true);
// process output
auto iter = std::find_if(graph_rets.begin(), graph_rets.end(),
[&t](const std::pair<AnfNodePtr, size_t> &ret) { return ret.first == t; });
if (iter != graph_rets.end()) {
auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t);
auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second);
auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin());
if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) {
need_insert_cast[iter->second] = false;
} else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) {
need_insert_cast[iter->second] = false;
}
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
} else {
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
}
if (t_new_node_1 != nullptr && t_new_node_1 != t) {
(void)mng->Replace(t, t_new_node_1);
}
}
// insert cast for graph kernel.
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
// process input
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto new_node = InsertCastForInput(func_graph, cnode);
// process output
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
} }
} // namespace } // namespace
@ -196,11 +114,6 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) { if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
return nullptr; return nullptr;
} }
if (AnfAlgo::IsGraphKernel(node)) {
return ProcessGraphKernelOp(func_graph, node);
}
// insert cast for single op.
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
// process input // process input
CNodePtr cnode = node->cast<CNodePtr>(); CNodePtr cnode = node->cast<CNodePtr>();
@ -211,7 +124,7 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
kernel_graph->ReplaceInternalOutput(node, new_node); kernel_graph->ReplaceInternalOutput(node, new_node);
} }
// process output // process output
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true)); return InsertCastForOutput(func_graph, new_node);
} }
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -150,9 +150,6 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
return nullptr; return nullptr;
} }
auto next_cnode = next_node->cast<CNodePtr>(); auto next_cnode = next_node->cast<CNodePtr>();
if (AnfAlgo::IsGraphKernel(next_node)) {
return nullptr;
}
auto next_op_name = AnfAlgo::GetCNodeName(next_cnode); auto next_op_name = AnfAlgo::GetCNodeName(next_cnode);
if (next_op_name == prim::kPrimSend->name() || next_op_name == kStackPushOpName) { if (next_op_name == prim::kPrimSend->name() || next_op_name == kStackPushOpName) {
return nullptr; return nullptr;
@ -224,9 +221,6 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
return nullptr; return nullptr;
} }
MS_EXCEPTION_IF_NULL(prior_op); MS_EXCEPTION_IF_NULL(prior_op);
if (AnfAlgo::IsGraphKernel(prior_op)) {
return nullptr;
}
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name() || if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name() ||

View File

@ -1,99 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
#include <vector>
#include <memory>
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "base/core_ops.h"
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
auto input_format = AnfAlgo::GetInputFormat(cnode, 0);
if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) {
return nullptr;
}
if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode);
return cnode;
}
AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) {
auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
if (input_shape.size() != 5) {
return nullptr;
}
if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) {
return nullptr;
}
auto multiples = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrMultiples);
if (multiples.size() == 4 && multiples[1] == 1) {
multiples.push_back(1);
AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode);
}
return cnode;
}
AnfNodePtr ModifyAttrs(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode);
if (op_name == prim::kPrimTile->name()) {
return ModifyTileOpAttrs(cnode);
} else if (op_name == prim::kPrimReduceSum->name()) {
// kPrimReduceMean
// kPrimReduceSum
// kPrimReduceAll
// kPrimReduceMax
// kPrimReduceMin
return ModifyReduceOpsAttrs(cnode);
}
return nullptr;
}
} // namespace
const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
return nullptr;
}
MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node);
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(fg);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> todos;
kernel::GetValidKernelNodes(fg, &todos);
for (auto &t : todos) {
auto new_node = ModifyAttrs(t->cast<CNodePtr>());
if (new_node != nullptr && new_node != t) {
(void)manager->Replace(t, new_node);
}
}
return node;
}
} // namespace opt
} // namespace mindspore

View File

@ -1,33 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class ModifyOpAttrs : public PatternProcessPass {
public:
explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {}
~ModifyOpAttrs() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H

View File

@ -1,66 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
#include <vector>
#include <memory>
#include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "base/core_ops.h"
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode);
if (op_name != prim::kPrimReshape->name()) {
return nullptr;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) {
return nullptr;
}
return cnode->input(1);
}
} // namespace
const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
return nullptr;
}
MS_LOG(DEBUG) << "====process op: " << AnfAlgo::GetCNodeName(node);
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(fg);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> todos;
kernel::GetValidKernelNodes(fg, &todos);
for (auto &t : todos) {
auto new_node = RemoveReshapeOp(t->cast<CNodePtr>());
if (new_node != nullptr && new_node != t) {
(void)manager->Replace(t, new_node);
}
}
return node;
}
} // namespace opt
} // namespace mindspore

View File

@ -1,33 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class RemoveNoUseReshapeOp : public PatternProcessPass {
public:
explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {}
~RemoveNoUseReshapeOp() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H

View File

@ -122,9 +122,6 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f
if (node == nullptr || !node->isa<CNode>()) { if (node == nullptr || !node->isa<CNode>()) {
return nullptr; return nullptr;
} }
if (AnfAlgo::IsGraphKernel(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
std::vector<CNodePtr> cast_nodes; std::vector<CNodePtr> cast_nodes;

View File

@ -596,20 +596,21 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
std::vector<PrimitivePtr> GetFusibleOpList() { std::vector<PrimitivePtr> GetFusibleOpList() {
#if ENABLE_D #if ENABLE_D
std::vector<PrimitivePtr> fusible_basic_ops = { std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::KPrimTransData}; prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimInplaceAssign,
prim::KPrimTransData};
#elif ENABLE_GPU #elif ENABLE_GPU
std::vector<PrimitivePtr> fusible_basic_ops = { std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimAssign, prim::kPrimLessEqual, prim::kPrimGreaterEqual, prim::kPrimReduceMax, prim::kPrimReduceMin, prim::kPrimAssign, prim::kPrimLessEqual, prim::kPrimGreaterEqual, prim::kPrimReduceMax, prim::kPrimReduceMin,
prim::kPrimLess}; prim::kPrimLess, prim::kPrimInplaceAssign};
#else #else
std::vector<PrimitivePtr> fusible_basic_ops; std::vector<PrimitivePtr> fusible_basic_ops;
#endif #endif

View File

@@ -81,13 +81,9 @@ const AnfNodePtr AddTrainingAttr::Process(const FuncGraphPtr &func_graph, const
   if (iter == MarkOp.end()) {
     return nullptr;
   }
-  if (AnfAlgo::IsGraphKernel(node)) {
-    return nullptr;
-  } else {
-    auto cnode = node->cast<CNodePtr>();
-    AddAttrTraining(func_graph, cnode);
-    return cnode;
-  }
+  auto cnode = node->cast<CNodePtr>();
+  AddAttrTraining(func_graph, cnode);
+  return cnode;
 }
 } // namespace opt
 } // namespace mindspore

View File

@ -15,7 +15,6 @@
*/ */
#include "backend/optimizer/pass/convert_attr_to_unify_mindir.h" #include "backend/optimizer/pass/convert_attr_to_unify_mindir.h"
#include <vector>
#include <string> #include <string>
#include "utils/check_convert_utils.h" #include "utils/check_convert_utils.h"
@ -29,32 +28,22 @@ const AnfNodePtr ConvertAttrToUnifyMindIR::Process(const FuncGraphPtr &, const A
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
return nullptr; return nullptr;
} }
std::vector<AnfNodePtr> todos;
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
kernel::GetValidKernelNodes(sub_graph, &todos);
} else {
todos.push_back(node);
}
for (auto &t : todos) { CNodePtr cnode = node->cast<CNodePtr>();
CNodePtr cnode = t->cast<CNodePtr>(); auto inputs = cnode->inputs();
auto inputs = cnode->inputs(); AnfNodePtr op = inputs[0];
AnfNodePtr op = inputs[0]; if (IsValueNode<Primitive>(op)) {
if (IsValueNode<Primitive>(op)) { auto prim = GetValueNode<PrimitivePtr>(op);
auto prim = GetValueNode<PrimitivePtr>(op); auto attrs = prim->attrs();
auto attrs = prim->attrs(); std::string type_name = prim->name();
std::string type_name = prim->name(); for (auto attr : attrs) {
for (auto attr : attrs) { bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr.second);
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr.second); if (converted) {
if (converted) { prim->set_attr(attr.first, attr.second);
prim->set_attr(attr.first, attr.second); }
} bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(type_name, attr.first, &attr.second);
bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(type_name, attr.first, &attr.second); if (converted_ir_attr) {
if (converted_ir_attr) { prim->set_attr(attr.first, attr.second);
prim->set_attr(attr.first, attr.second);
}
} }
} }
} }

View File

@ -15,7 +15,6 @@
*/ */
#include "backend/optimizer/pass/convert_const_input_to_attr.h" #include "backend/optimizer/pass/convert_const_input_to_attr.h"
#include <vector>
#include <string> #include <string>
#include <memory> #include <memory>
@ -34,40 +33,31 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
return nullptr; return nullptr;
} }
std::vector<AnfNodePtr> todos;
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
kernel::GetValidKernelNodes(sub_graph, &todos);
} else {
todos.push_back(node);
}
for (auto &t : todos) { CNodePtr cnode = node->cast<CNodePtr>();
CNodePtr cnode = t->cast<CNodePtr>(); ConstInputToAttrInfoRegister reg;
ConstInputToAttrInfoRegister reg; if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), &reg)) {
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), &reg)) { return nullptr;
continue;
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
continue;
}
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
auto ms_context = MsContext::GetInstance();
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
continue;
}
}
if (AnfAlgo::IsDynamicShape(cnode) &&
DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
continue;
}
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
} }
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
return nullptr;
}
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
auto ms_context = MsContext::GetInstance();
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
return nullptr;
}
}
if (AnfAlgo::IsDynamicShape(cnode) &&
DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
return nullptr;
}
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
return node; return node;
} }
} // namespace opt } // namespace opt

View File

@ -100,24 +100,6 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
} }
return nullptr; return nullptr;
} }
AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
auto mng = sub_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<AnfNodePtr> todo;
kernel::GetValidKernelNodes(sub_graph, &todo);
for (auto &t : todo) {
auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast<CNodePtr>());
if (t_new_node != nullptr && t_new_node != t) {
(void)mng->Replace(t, t_new_node);
}
}
return node;
}
} // namespace } // namespace
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
@ -129,11 +111,8 @@ const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &fun
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
return nullptr; return nullptr;
} }
if (AnfAlgo::IsGraphKernel(node)) {
return ProcessGraphKernelOp(node); return ConstInputToTensorInput(func_graph, node->cast<CNodePtr>());
} else {
return ConstInputToTensorInput(func_graph, node->cast<CNodePtr>());
}
} }
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -95,15 +95,6 @@ const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &fu
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) { if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
return nullptr; return nullptr;
} }
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
std::vector<AnfNodePtr> todos;
kernel::GetValidKernelNodes(sub_graph, &todos);
for (auto &t : todos) {
ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast<CNodePtr>());
}
}
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>()); ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
return node; return node;
} }

View File

@ -170,26 +170,6 @@ const AnfNodePtr EliminateRedundantOp::Process(const FuncGraphPtr &func_graph, c
if (cnode == nullptr || func_graph == nullptr) { if (cnode == nullptr || func_graph == nullptr) {
return nullptr; return nullptr;
} }
if (AnfAlgo::IsGraphKernel(node)) {
// do eliminate for ops in graph kernel.
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(sub_graph);
auto mng = sub_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
std::vector<AnfNodePtr> todo;
kernel::GetValidKernelNodes(sub_graph, &todo);
for (auto &t : todo) {
CNodePtr t_cnode = t->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(t_cnode);
auto t_new_node = DoEliminate(sub_graph, t_cnode);
if (t_new_node != nullptr && t_new_node != t) {
(void)mng->Replace(t, t_new_node);
}
}
return node;
}
// do eliminate for single op.
return DoEliminate(func_graph, cnode); return DoEliminate(func_graph, cnode);
} }
} // namespace opt } // namespace opt

View File

@@ -30,20 +30,7 @@ const BaseRef EraseVisitAttr::DefinePattern() const {
 }
 const AnfNodePtr EraseVisitAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
-  if (node != nullptr && AnfAlgo::IsRealCNodeKernel(node)) {
-    if (AnfAlgo::IsGraphKernel(node)) {
-      auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
-      MS_EXCEPTION_IF_NULL(fg);
-      std::vector<AnfNodePtr> todos;
-      kernel::GetValidKernelNodes(fg, &todos);
-      for (auto &t : todos) {
-        AnfAlgo::EraseNodeAttr(kAttrVisited, t);
-      }
-    }
-    AnfAlgo::EraseNodeAttr(kAttrVisited, node);
-  } else {
-    AnfAlgo::EraseNodeAttr(kAttrVisited, node);
-  }
+  AnfAlgo::EraseNodeAttr(kAttrVisited, node);
   return nullptr;
 }
 } // namespace opt

View File

@ -949,7 +949,6 @@ void AscendSession::InitRuntimeResource() {
void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const { void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "HardwareOptimize start!"; MS_LOG(INFO) << "HardwareOptimize start!";
opt::AscendBackendOptimization(kernel_graph); opt::AscendBackendOptimization(kernel_graph);
opt::AscendGraphKernelCommonProcess(kernel_graph);
GraphKernelOptimize(kernel_graph); GraphKernelOptimize(kernel_graph);
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
kernel_graph->SetExecOrderByDefault(); kernel_graph->SetExecOrderByDefault();

View File

@@ -433,7 +433,9 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
   auto cnode = FuncGraph::NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(cnode);
   cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
-  CreateKernelInfoFromNewParameter(cnode);
+  if (AnfAlgo::IsGraphKernel(cnode)) {
+    CreateKernelInfoFromNewParameter(cnode);
+  }
   if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
     AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
   }
@@ -443,9 +445,6 @@
 }
 void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
-  if (!AnfAlgo::IsGraphKernel(cnode)) {
-    return;
-  }
   auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
   MS_EXCEPTION_IF_NULL(func_graph);

View File

@ -45,10 +45,6 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info())); TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
k_graph_ = std::make_shared<FuncGraph>(); k_graph_ = std::make_shared<FuncGraph>();
} }
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
}
// To keep switch_layer's inputs from being inlined // To keep switch_layer's inputs from being inlined
k_graph_->set_switch_layer_input(primal_graph->switch_layer_input()); k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
k_graph_->set_stage(primal_graph->stage()); k_graph_->set_stage(primal_graph->stage());
@ -58,11 +54,6 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
tape_ = std::make_shared<FuncGraph>(); tape_ = std::make_shared<FuncGraph>();
} }
tape_->set_stage(primal_graph->stage()); tape_->set_stage(primal_graph->stage());
// Add "_Grad" postfix
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad";
tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
}
dout_ = tape_->add_parameter(); dout_ = tape_->add_parameter();
} }

View File

@ -194,8 +194,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Incorporation // Incorporation
incorporate_getitem_set_ = incorporate_getitem_set_ =
MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem); MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(),
"incorporate_getitem_from_param", IsCNodeGraphKernel);
incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup); incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ = incorporate_call_switch_ =
MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup); MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
@ -211,19 +209,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
print_tuple_wrapper_ = print_tuple_wrapper_ =
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint); MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
// Unused parameter eliminate
unused_parameter_eliminate_ =
MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel);
unused_output_eliminate_ =
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
// tuple parameter graph transform // tuple parameter graph transform
call_graph_tuple_transform_ = call_graph_tuple_transform_ =
MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transorm", IsCNode); MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transorm", IsCNode);
// AddN eliminate
addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
// RowTensor Eliminate // RowTensor Eliminate
row_tensor_eliminate_ = MakeSubstitution( row_tensor_eliminate_ = MakeSubstitution(
std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate", std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",

View File

@ -112,7 +112,6 @@ class OptimizeIRPassLib {
// Incorporation // Incorporation
SubstitutionPtr incorporate_getitem_set_; SubstitutionPtr incorporate_getitem_set_;
SubstitutionPtr incorporate_getitem_from_param_;
SubstitutionPtr incorporate_call_; SubstitutionPtr incorporate_call_;
SubstitutionPtr incorporate_call_switch_; SubstitutionPtr incorporate_call_switch_;
@ -125,16 +124,9 @@ class OptimizeIRPassLib {
// Convert // Convert
SubstitutionPtr print_tuple_wrapper_; SubstitutionPtr print_tuple_wrapper_;
// Unused parameter eliminate
SubstitutionPtr unused_parameter_eliminate_;
SubstitutionPtr unused_output_eliminate_;
// tuple parameter graph transform // tuple parameter graph transform
SubstitutionPtr call_graph_tuple_transform_; SubstitutionPtr call_graph_tuple_transform_;
// AddN eliminate
SubstitutionPtr addn_eliminate_;
// RowTensor Eliminate // RowTensor Eliminate
SubstitutionPtr row_tensor_eliminate_; SubstitutionPtr row_tensor_eliminate_;
@ -214,23 +206,6 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) {
return IsValueNode<FuncGraph>(inp0); return IsValueNode<FuncGraph>(inp0);
} }
// Check if CNode Input 0 is Func Graph of graph kernel.
inline bool IsCNodeGraphKernel(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) {
return false;
}
auto inp0 = node->cast<CNodePtr>()->input(0);
if (IsValueNode<FuncGraph>(inp0)) {
auto fg = GetValueNode<FuncGraphPtr>(inp0);
if (fg == nullptr) {
return false;
}
return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}
return false;
}
// Check if CNode Input 0 is CNode // Check if CNode Input 0 is CNode
inline bool IsCNodeDup(const AnfNodePtr &node) { inline bool IsCNodeDup(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) { if (node == nullptr || !node->isa<CNode>()) {

View File

@ -20,7 +20,6 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include <utility> #include <utility>
@ -135,25 +134,6 @@ class IncorporateGetitem : public AnfVisitor {
return nullptr; return nullptr;
} }
if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
// If graph kernel has muti output, do not split.
// some graph kernel output has EnvInstance node or DeadCode node should split.
auto output = fg_->output();
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
auto output_cnode = output->cast<CNodePtr>();
auto outputs = output_cnode->inputs();
int64_t real_output_cnt = 0;
for (size_t i = 1; i < outputs.size(); ++i) {
if (IsCNode(outputs[i]) || IsValueNode<tensor::Tensor>(outputs[i]) || IsParam(outputs[i])) {
real_output_cnt++;
if (real_output_cnt > 1) {
return nullptr;
}
}
}
}
}
auto new_fg = getitem_transform_(fg_, idx_); auto new_fg = getitem_transform_(fg_, idx_);
(void)args_.insert(args_.begin(), NewValueNode(new_fg)); (void)args_.insert(args_.begin(), NewValueNode(new_fg));
return node->func_graph()->NewCNode(args_); return node->func_graph()->NewCNode(args_);
@ -184,171 +164,6 @@ class IncorporateGetitem : public AnfVisitor {
internal::GetitemTransform getitem_transform_; internal::GetitemTransform getitem_transform_;
}; };
class IncorporateGetitemFromParam : public AnfVisitor {
public:
void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &param, size_t input_idx) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto &node_users = mng->node_users();
if (node_users.find(param) == node_users.end() || node_users[param].empty()) {
args_.push_back(cnode->input(input_idx + 1));
return;
}
for (auto &user : node_users[param]) {
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
// we do not process this case.
args_.push_back(cnode->input(input_idx + 1));
return;
}
}
// update new args.
if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) {
// case 1
replace_parameters_[input_idx] = true;
need_update_ = true;
auto make_tuple_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs();
inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1;
args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end());
} else {
// case 2
auto prev_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
auto prev_fg = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
auto fg_output = prev_fg->output();
if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) {
MS_LOG(ERROR) << "The return of: " << prev_fg->ToString()
<< " should be a make tuple, but got: " << fg_output->DebugString();
return;
}
replace_parameters_[input_idx] = true;
need_update_ = true;
auto make_tuple_cnode = fg_output->cast<CNodePtr>();
inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1;
for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) {
auto new_getitem =
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToLong(output_i))});
auto aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(SizeToLong(output_i)));
new_getitem->input(2)->set_abstract(aptr);
new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract());
args_.push_back(new_getitem);
}
}
}
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (node->func_graph() == nullptr) {
return nullptr;
}
Reset();
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return nullptr;
}
auto &inputs = cnode->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
if (fg == nullptr) {
return nullptr;
}
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
auto parameters = fg->parameters();
if (parameters.size() != inputs.size() - 1) {
return nullptr;
}
replace_parameters_ = std::vector<bool>(parameters.size(), false);
inputs_num_ = std::vector<size_t>(parameters.size(), 1);
auto node_fg = node->func_graph();
for (size_t i = 1; i < inputs.size(); ++i) {
if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) {
Process(node_fg, cnode, parameters[i - 1], i - 1);
} else {
args_.push_back(inputs[i]);
}
}
if (!need_update_) {
return nullptr;
}
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
mng->AddFuncGraph(new_fg);
auto node_users = mng->node_users();
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
std::vector<AnfNodePtr> new_parameters;
size_t curr_input_idx{0};
for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) {
if (!replace_parameters_[param_i]) {
if (parameters[param_i]->abstract() != nullptr) {
new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract());
}
new_parameters.push_back(new_fg_parameters[param_i]);
curr_input_idx++;
continue;
}
// make a new parameter.
for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) {
auto new_param = std::make_shared<Parameter>(new_fg);
new_param->set_abstract(args_.at(curr_input_idx)->abstract());
// update users of new parameter.
for (auto &user : node_users[new_fg_parameters[param_i]]) {
idx_ = -1;
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode<Int64Imm>})(user.first);
if (idx_ == -1) {
MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString()
<< " must be tuple getitem here, but got: " << user.first->DebugString();
return nullptr;
}
if (input_i == LongToSize(idx_)) {
for (auto &sub_user : node_users[user.first]) {
auto sub_user_cnode = sub_user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sub_user_cnode);
sub_user_cnode->set_input(sub_user.second, new_param);
(void)mng->Replace(sub_user.first, sub_user_cnode);
}
}
}
new_parameters.push_back(new_param);
curr_input_idx++;
}
}
mng->SetParameters(new_fg, new_parameters);
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
auto new_call = node_fg->NewCNode(args_);
new_call->set_abstract(node->abstract());
return new_call;
}
void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int64_t>(vnode->value()); }
void Visit(const CNodePtr &cnode) override {}
void Reset() {
replace_parameters_.clear();
args_.clear();
inputs_num_.clear();
need_update_ = false;
idx_ = -1;
}
private:
std::vector<bool> replace_parameters_{};
std::vector<AnfNodePtr> args_{};
std::vector<size_t> inputs_num_{};
bool need_update_{false};
int64_t idx_{-1};
};
// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C} // {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C}
class IncorporateGetitemSwitch : public AnfVisitor { class IncorporateGetitemSwitch : public AnfVisitor {
public: public:

View File

@ -103,13 +103,6 @@ class InlinerBase : public AnfVisitor {
return nullptr; return nullptr;
} }
// Do not inline GraphKernel to Cell.
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
// If the GraphKernel only contains a return node, we make it inlined.
if (fg->nodes().size() - fg->parameters().size() > 1) {
return nullptr;
}
}
Reset(); Reset();
// 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...} // 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}

View File

@ -205,130 +205,6 @@ class AddNZeroFilter : public AnfVisitor {
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{}; std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
bool has_zero_like_{false}; bool has_zero_like_{false};
}; };
// {PrimAddN, {kPrimMakeTuple, Xs}}
// Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd.
// case0: AddN(inputs)(inputs size < 2) -> error
// case1: AddN(inputs)(all inputs is ValueNode) -> error
// case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor)
// case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input)
// -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
class AddNEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(fg);
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
if (fg->recursive()) {
return nullptr;
}
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
mng->AddFuncGraph(new_fg);
need_update_ = false;
bool changed;
do {
changed = Process(new_fg);
} while (changed);
if (!need_update_) {
return nullptr;
} else {
auto new_sx = inputs;
new_sx[0] = NewValueNode(new_fg);
return node->func_graph()->NewCNode(new_sx);
}
}
bool Process(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto nodes = TopoSort(func_graph->output());
bool changed = false;
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &tuple_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(tuple_input);
auto tuple_input_cnode = tuple_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_input_cnode);
auto &tuple_inputs = tuple_input_cnode->inputs();
if (tuple_inputs.size() < 3) {
// case0: inputs size < 2, error
MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2);
}
int64_t valuenode_num = std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0,
[](int64_t accumulator, const AnfNodePtr &node) {
if (IsValueNode<tensor::Tensor>(node)) {
return accumulator + 1;
} else {
return accumulator;
}
});
if (LongToSize(valuenode_num) == tuple_inputs.size()) {
// case1: all inputs is ValueNode, error
MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2);
}
if (tuple_inputs.size() == 3) {
// case2: inputs size = 2, -> TensorAdd(Tensor, Tensor)
MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2);
ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1],
tuple_inputs[2]};
mng->Replace(node, func_graph->NewCNode(new_xs));
changed = true;
continue;
}
auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(),
[](const AnfNodePtr &node) { return IsValueNode<tensor::Tensor>(node); });
if (first_valuenode == tuple_inputs.end()) {
// no ValueNode input found.
continue;
} else {
// case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
std::vector<AnfNodePtr> make_tuple_new_xs{
NewValueNode(prim::kPrimMakeTuple),
};
std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(),
[&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) {
if (node != *first_valuenode) {
make_tuple_new_xs.push_back(node);
}
});
ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations");
auto new_addn = func_graph->NewCNode(
{func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)});
ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
auto new_add =
func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn});
(void)mng->Replace(node, new_add);
changed = true;
continue;
}
}
need_update_ = need_update_ || changed;
return changed;
}
private:
bool need_update_{false};
};
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
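For reference, a minimal Python sketch, not part of this diff, of the algebraic rewrites described in the removed AddNEliminater comment above; the tensor shapes and values are illustrative assumptions, and it assumes a MindSpore install running in PyNative mode.

import numpy as np
import mindspore as ms
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

a = Tensor(np.ones((2, 2)), ms.float32)
b = Tensor(np.ones((2, 2)), ms.float32)
c = Tensor(np.full((2, 2), 3.0), ms.float32)  # stands in for a constant (ValueNode) input

# case2: AddN with exactly two inputs is equivalent to a single Add.
assert np.allclose(P.AddN()((a, b)).asnumpy(), P.Add()(a, b).asnumpy())

# case3: one constant input is peeled off; the remaining tensors stay in AddN.
lhs = P.AddN()((c, a, b))
rhs = P.Add()(c, P.AddN()((a, b)))
assert np.allclose(lhs.asnumpy(), rhs.asnumpy())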


@ -22,7 +22,6 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <tuple> #include <tuple>
#include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass.h"
@ -162,146 +161,6 @@ class SpecializeOnGraphArguments : public AnfVisitor {
private: private:
internal::SpecializeTransform specialize_transform_; internal::SpecializeTransform specialize_transform_;
}; };
// Eliminate unused parameters.
// {G, Xs}
class UnusedParasEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(fg);
std::vector<AnfNodePtr> parameters = fg->parameters();
size_t size = parameters.size();
if (size != inputs.size() - 1) {
return nullptr;
}
std::vector<AnfNodePtr> new_xs;
std::vector<bool> keep_parameters;
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
auto &node_users = mng->node_users();
bool has_unused_para = false;
for (size_t i = 0; i < size; ++i) {
auto iter = node_users.find(parameters[i]);
if (iter != node_users.end() && !iter->second.empty()) {
keep_parameters.push_back(true);
new_xs.push_back(inputs[i + 1]);
continue;
}
keep_parameters.push_back(false);
has_unused_para = true;
}
if (!has_unused_para) {
return nullptr;
}
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
mng->AddFuncGraph(new_fg);
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
std::vector<AnfNodePtr> new_parameters;
for (size_t i = 0; i < size; i++) {
if (keep_parameters[i]) {
if (parameters[i]->abstract() != nullptr) {
new_fg_parameters[i]->set_abstract(parameters[i]->abstract());
}
new_parameters.push_back(new_fg_parameters[i]);
}
}
mng->SetParameters(new_fg, new_parameters);
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
return node->func_graph()->NewCNode(new_xs);
}
};
// Eliminate unused outputs.
// {G, Xs}
class UnusedOutputEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(fg);
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
if (fg->recursive()) {
return nullptr;
}
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
mng->AddFuncGraph(new_fg);
auto new_fg_output = new_fg->output();
if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) {
return nullptr;
}
auto output_cnode = new_fg_output->cast<CNodePtr>();
auto &node_users = mng->node_users();
if (node_users.count(node) == 0 || node_users[node].empty()) {
return nullptr;
}
std::unordered_set<int64_t> used_output_idx;
std::vector<std::pair<AnfNodePtr, int64_t>> all_users;
for (auto &node_user : node_users[node]) {
if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
return nullptr;
}
auto user_cnode = node_user.first->cast<CNodePtr>();
size_t used_idx = GetValue<int64_t>(user_cnode->input(2)->cast<ValueNodePtr>()->value());
used_output_idx.insert(used_idx);
all_users.push_back(std::make_pair(node_user.first, used_idx));
}
if (used_output_idx.size() >= output_cnode->inputs().size() - 1) {
// all outputs have users.
return nullptr;
}
if (used_output_idx.empty()) {
// we do not process this case.
return nullptr;
} else if (used_output_idx.size() == 1) {
// after elimination, only one output is left.
new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1));
// update users.
for (auto &ret_user : all_users) {
(void)mng->Replace(ret_user.first, node);
}
} else {
// after elimination, create a new multi-output tuple.
std::vector<AnfNodePtr> new_output_inputs{output_cnode->input(0)};
std::unordered_map<int64_t, int64_t> new_idx_map;
for (auto idx : used_output_idx) {
new_idx_map[idx] = SizeToLong(new_output_inputs.size() - 1);
new_output_inputs.push_back(output_cnode->input(idx + 1));
}
new_fg->set_output(new_fg->NewCNode(new_output_inputs));
// update users.
for (auto &ret_user : all_users) {
auto ret_user_cnode = ret_user.first->cast<CNodePtr>();
ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second]));
}
}
auto new_sx = inputs;
new_sx[0] = NewValueNode(new_fg);
return node->func_graph()->NewCNode(new_sx);
}
};
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore


@ -39,7 +39,6 @@
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/optimizer/recompute.h" #include "frontend/optimizer/recompute.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/context/graph_kernel_flags.h"
#include "pipeline/jit/pipeline_split.h" #include "pipeline/jit/pipeline_split.h"
#include "pipeline/jit/static_analysis/auto_monad.h" #include "pipeline/jit/static_analysis/auto_monad.h"
#include "frontend/optimizer/irpass/gradient_eliminate.h" #include "frontend/optimizer/irpass/gradient_eliminate.h"
@ -296,31 +295,6 @@ OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &i
return map; return map;
} }
OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) {
OptPassGroupMap map({
{"renormalize", opt::OptPassConfig::Renormalize()},
{"cse", opt::OptPassConfig(opt::CSEPass(false))},
});
return map;
}
OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig elim_1 = opt::OptPassConfig({
irpass.addn_eliminate_,
irpass.incorporate_getitem_from_param_,
});
opt::OptPassConfig elim_2 = opt::OptPassConfig({
irpass.unused_parameter_eliminate_,
irpass.unused_output_eliminate_,
});
OptPassGroupMap map({
{"elim_1", elim_1},
{"renormalize", opt::OptPassConfig::Renormalize()},
{"elim_2", elim_2},
});
return map;
}
OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) {
return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}}); return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}});
} }
@ -375,10 +349,6 @@ void InitOpt(const ResourcePtr &res) {
Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true); Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
g_pass_opts["opt_trans_graph"] = g_pass_opts["opt_trans_graph"] =
Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true); Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true);
g_pass_opts["opt_graph_kernel_a"] =
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
g_pass_opts["opt_graph_kernel_b"] =
Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false);
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass)); g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
g_pass_opts["opt_grad_epilogue"] = g_pass_opts["opt_grad_epilogue"] =
@ -386,10 +356,6 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
g_pass_opts["opt_after_recompute"] = g_pass_opts["opt_after_recompute"] =
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass)); Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
}
} }
} }
} // namespace } // namespace
@ -424,8 +390,6 @@ bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a");
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); } bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); } bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); } bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
@ -559,8 +523,6 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"opt_after_cconv", OptPassAfterCconvGroup}, {"opt_after_cconv", OptPassAfterCconvGroup},
{"remove_dup_value", RemoveValueNodeDuplicationsPass}, {"remove_dup_value", RemoveValueNodeDuplicationsPass},
{"tuple_transform", OptPassTransformGraphGroup}, {"tuple_transform", OptPassTransformGraphGroup},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_cache_embedding", AddCacheEmbeddingPass}, {"add_cache_embedding", AddCacheEmbeddingPass},
{"add_recomputation", AddRecomputationPass}, {"add_recomputation", AddRecomputationPass},
{"cse_after_recomputation", OptAfterRecomputeGroup}}; {"cse_after_recomputation", OptAfterRecomputeGroup}};


@ -216,9 +216,7 @@ void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const {
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>()); pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>()); pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>()); pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
auto context_ptr = MsContext::GetInstance(); if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all")); pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
} }
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum")); pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));


@ -602,12 +602,6 @@ bool GraphPartition::IsCut(const AnfNodePtr &node) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
} }
AnfNodePtr fn = inputs[0]; AnfNodePtr fn = inputs[0];
if (IsValueNode<FuncGraph>(fn)) {
auto fg = GetValueNode<FuncGraphPtr>(fn);
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
return false;
}
}
if (!IsValueNode<Primitive>(fn)) { if (!IsValueNode<Primitive>(fn)) {
return true; return true;
} }


@ -510,7 +510,7 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(prim_graph); MS_EXCEPTION_IF_NULL(prim_graph);
FuncGraphSet graphs = prim_graph->manager()->func_graphs(); FuncGraphSet graphs = prim_graph->manager()->func_graphs();
for (auto g : graphs) { for (auto g : graphs) {
if (g != graph && g != nullptr && !(g->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) { if (g != graph && g != nullptr) {
Compile(g); Compile(g);
} }
} }
@ -542,7 +542,7 @@ uint32_t GraphCompiler::CompileGraphs(const FuncGraphPtr &func_graph) {
// Compile sub graphs. // Compile sub graphs.
FuncGraphSet sub_graphs = root_graph->manager()->func_graphs(); FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
for (auto sub_graph : sub_graphs) { for (auto sub_graph : sub_graphs) {
if (sub_graph != func_graph && sub_graph != nullptr && !(sub_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) { if (sub_graph != func_graph && sub_graph != nullptr) {
(void)CompileGraph(sub_graph); (void)CompileGraph(sub_graph);
} }
} }


@ -23,6 +23,7 @@ import numpy
from mindspore import log as logger from mindspore import log as logger
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.common._decorator import deprecated
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from .. import context from .. import context
from .._c_expression import init_pipeline, Cell_, FuncGraph from .._c_expression import init_pipeline, Cell_, FuncGraph
@ -1213,6 +1214,11 @@ class GraphKernel(Cell):
A `GraphKernel` is a composite of basic primitives and can be compiled into a fused kernel automatically when
enable_graph_kernel in context is set to True.
This class is deprecated since version 1.3 and will be removed in a future version; use Cell instead.
GraphKernel no longer supports user-defined cells; `GraphKernel` objects will be treated as
normal `Cell` objects.
Args: Args:
auto_prefix (bool): Recursively generate namespaces. Default: True. auto_prefix (bool): Recursively generate namespaces. Default: True.
flags (dict) : Set graph flags. Default: None. flags (dict) : Set graph flags. Default: None.
@ -1230,10 +1236,9 @@ class GraphKernel(Cell):
... return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x) ... return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
""" """
@deprecated("1.3", "Cell", True)
def __init__(self, auto_prefix=True, flags=None): def __init__(self, auto_prefix=True, flags=None):
super(GraphKernel, self).__init__(auto_prefix, flags) super(GraphKernel, self).__init__(auto_prefix, flags)
class_name = self.__class__.__name__
self.add_flags(graph_kernel=class_name)
def construct(self): def construct(self):
raise NotImplementedError raise NotImplementedError
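A minimal sketch of how a cell that used to subclass the deprecated GraphKernel can now be written as a plain Cell, per the note above; the MaxClip name is an illustrative assumption, and fusion is driven by enable_graph_kernel in the context rather than by the base class.

import mindspore.nn as nn
from mindspore.ops import operations as P

class MaxClip(nn.Cell):
    """Computes max(x, 0) from basic primitives, as an ordinary Cell."""
    def __init__(self):
        super(MaxClip, self).__init__()
        self.max = P.Maximum()
        self.fill = P.Fill()
        self.dtype = P.DType()
        self.shape = P.Shape()

    def construct(self, x):
        zeros = self.fill(self.dtype(x), self.shape(x), 0.0)
        return self.max(zeros, x)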


@ -16,7 +16,6 @@
import numpy as np import numpy as np
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import _selected_ops
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
@ -87,7 +86,7 @@ class Softmax(Cell):
def __init__(self, axis=-1): def __init__(self, axis=-1):
super(Softmax, self).__init__() super(Softmax, self).__init__()
self.softmax = _selected_ops.Softmax(axis) self.softmax = P.Softmax(axis)
def construct(self, x): def construct(self, x):
return self.softmax(x) return self.softmax(x)
@ -137,7 +136,7 @@ class LogSoftmax(Cell):
def __init__(self, axis=-1): def __init__(self, axis=-1):
super(LogSoftmax, self).__init__() super(LogSoftmax, self).__init__()
self.log_softmax = _selected_ops.LogSoftmax(axis) self.log_softmax = P.LogSoftmax(axis)
def construct(self, x): def construct(self, x):
return self.log_softmax(x) return self.log_softmax(x)
@ -368,7 +367,7 @@ class Tanh(Cell):
def __init__(self): def __init__(self):
super(Tanh, self).__init__() super(Tanh, self).__init__()
self.tanh = _selected_ops.Tanh() self.tanh = P.Tanh()
def construct(self, x): def construct(self, x):
return self.tanh(x) return self.tanh(x)
@ -415,7 +414,7 @@ class GELU(Cell):
def __init__(self): def __init__(self):
super(GELU, self).__init__() super(GELU, self).__init__()
self.gelu = _selected_ops.GeLU() self.gelu = P.GeLU()
def construct(self, x): def construct(self, x):
return self.gelu(x) return self.gelu(x)
@ -458,7 +457,7 @@ class FastGelu(Cell):
def __init__(self): def __init__(self):
super(FastGelu, self).__init__() super(FastGelu, self).__init__()
self.fast_gelu = _selected_ops.FastGeLU() self.fast_gelu = P.FastGeLU()
def construct(self, x): def construct(self, x):
return self.fast_gelu(x) return self.fast_gelu(x)


@ -30,7 +30,6 @@ from mindspore._checkparam import Validator as validator
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management from mindspore.communication import management
from mindspore.ops import _selected_ops
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from ..cell import Cell from ..cell import Cell
@ -837,9 +836,9 @@ class LayerNorm(Cell):
gamma_init, normalized_shape), name="gamma") gamma_init, normalized_shape), name="gamma")
self.beta = Parameter(initializer( self.beta = Parameter(initializer(
beta_init, normalized_shape), name="beta") beta_init, normalized_shape), name="beta")
self.layer_norm = _selected_ops.LayerNorm(begin_norm_axis=self.begin_norm_axis, self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis,
begin_params_axis=self.begin_params_axis, begin_params_axis=self.begin_params_axis,
epsilon=self.epsilon) epsilon=self.epsilon)
def construct(self, input_x): def construct(self, input_x):
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta) y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)


@ -21,7 +21,6 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore import nn from mindspore import nn
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
from mindspore.ops import _selected_ops
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation from mindspore.nn.layer.activation import get_activation
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
@ -48,7 +47,7 @@ class _Loss(Cell):
if reduction == 'none': if reduction == 'none':
self.reduce = False self.reduce = False
self.reduce_mean = _selected_ops.ReduceMean() self.reduce_mean = P.ReduceMean()
self.reduce_sum = P.ReduceSum() self.reduce_sum = P.ReduceSum()
self.mul = P.Mul() self.mul = P.Mul()
self.cast = P.Cast() self.cast = P.Cast()
@ -381,7 +380,7 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction) super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
self.sparse = validator.check_bool(sparse, "sparse") self.sparse = validator.check_bool(sparse, "sparse")
self.reduction = reduction self.reduction = reduction
self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits() self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
self.one_hot = P.OneHot() self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32) self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0., mstype.float32) self.off_value = Tensor(0., mstype.float32)


@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
"""momentum""" """momentum"""
from mindspore.ops import functional as F, composite as C, operations as P from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops import _selected_ops
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
@ -157,7 +156,7 @@ class Momentum(Optimizer):
self.use_nesterov = Validator.check_bool(use_nesterov) self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros') self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov) self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
def construct(self, gradients): def construct(self, gradients):
params = self.params params = self.params


@ -19,7 +19,6 @@ from functools import reduce
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import nn from mindspore import nn
from mindspore.ops import _selected_grad_ops as SG
from .. import functional as F from .. import functional as F
from .. import operations as P from .. import operations as P
from ..operations import _grad_ops as G from ..operations import _grad_ops as G
@ -590,7 +589,7 @@ def get_bprop_expm1(self):
@bprop_getters.register(P.Minimum) @bprop_getters.register(P.Minimum)
def get_bprop_minimum(self): def get_bprop_minimum(self):
"""Grad definition for `Minimum` operation.""" """Grad definition for `Minimum` operation."""
input_grad = SG.MinimumGrad() input_grad = G.MinimumGrad()
def bprop(x, y, out, dout): def bprop(x, y, out, dout):
dx, dy = input_grad(x, y, dout) dx, dy = input_grad(x, y, dout)
@ -602,7 +601,7 @@ def get_bprop_minimum(self):
@bprop_getters.register(P.Maximum) @bprop_getters.register(P.Maximum)
def get_bprop_maximum(self): def get_bprop_maximum(self):
"""Grad definition for `Maximum` operation.""" """Grad definition for `Maximum` operation."""
input_grad = SG.MaximumGrad() input_grad = G.MaximumGrad()
def bprop(x, y, out, dout): def bprop(x, y, out, dout):
dx, dy = input_grad(x, y, dout) dx, dy = input_grad(x, y, dout)
@ -1107,7 +1106,7 @@ def get_bprop_cosh(self):
@bprop_getters.register(P.Abs) @bprop_getters.register(P.Abs)
def get_bprop_abs(self): def get_bprop_abs(self):
"""Grad definition for `Abs` operation.""" """Grad definition for `Abs` operation."""
abs_grad = SG.AbsGrad() abs_grad = G.AbsGrad()
def bprop(x, out, dout): def bprop(x, out, dout):
dx = abs_grad(x, dout) dx = abs_grad(x, dout)


@ -15,7 +15,6 @@
"""Define the grad rules of neural network related operations.""" """Define the grad rules of neural network related operations."""
import os import os
from mindspore.ops import _selected_grad_ops as SG
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.ops.operations import nn_ops as nps from mindspore.ops.operations import nn_ops as nps
@ -34,7 +33,7 @@ env_force_bprop_seq = os.getenv("ENV_FORCE_BPROP_SEQ")
@bprop_getters.register(P.BiasAdd) @bprop_getters.register(P.BiasAdd)
def get_bprop_bias_add(self): def get_bprop_bias_add(self):
"""Grad definition for `BiasAdd` operation.""" """Grad definition for `BiasAdd` operation."""
bias_grad = SG.BiasAddGrad(self.data_format) bias_grad = G.BiasAddGrad(self.data_format)
def bprop(x, w, out, dout): def bprop(x, w, out, dout):
return dout, bias_grad(dout) return dout, bias_grad(dout)
@ -341,7 +340,7 @@ def get_bprop_dropout_do_mask(self):
def get_bprop_mish(self): def get_bprop_mish(self):
"""Grad definition for `Mish` operation.""" """Grad definition for `Mish` operation."""
tanh = P.Tanh() tanh = P.Tanh()
tanh_grad = SG.TanhGrad() tanh_grad = G.TanhGrad()
softplus = P.Softplus() softplus = P.Softplus()
softplus_grad = G.SoftplusGrad() softplus_grad = G.SoftplusGrad()
@ -580,7 +579,7 @@ def get_bprop_softsign(self):
@bprop_getters.register(P.Tanh) @bprop_getters.register(P.Tanh)
def get_bprop_tanh(self): def get_bprop_tanh(self):
"""Grad definition for `Tanh` operation.""" """Grad definition for `Tanh` operation."""
tanh_grad = SG.TanhGrad() tanh_grad = G.TanhGrad()
def bprop(x, out, dout): def bprop(x, out, dout):
dx = tanh_grad(out, dout) dx = tanh_grad(out, dout)


@ -1,50 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" resolved grad ops """
from mindspore.ops.op_selector import new_ops_selector
op_selector = new_ops_selector(
"mindspore.ops.operations._grad_ops", "mindspore.nn._graph_kernels")
@op_selector
class MaximumGrad:
def __call__(self, *args):
pass
@op_selector
class MinimumGrad:
def __call__(self, *args):
pass
@op_selector
class AbsGrad:
def __call__(self, *args):
pass
@op_selector
class BiasAddGrad:
def __call__(self, *args):
pass
@op_selector
class TanhGrad:
def __call__(self, *args):
pass


@ -1,120 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" resolve ops """
from mindspore.ops.op_selector import new_ops_selector
op_selector = new_ops_selector(
"mindspore.ops.operations", "mindspore.nn._graph_kernels")
opt_selector = new_ops_selector(
"mindspore.nn.optim", "mindspore.nn._graph_kernels")
nn_selector = new_ops_selector(
"mindspore.nn", "mindspore.nn._graph_kernels")
@nn_selector
class BatchNorm2d:
def __call__(self, *args):
pass
@op_selector
class ReLU:
def __call__(self, *args):
pass
@op_selector
class ReduceMean:
def __call__(self, *args):
pass
@op_selector
class BiasAdd:
def __call__(self, *args):
pass
@op_selector
class ApplyMomentum:
def __call__(self, *args):
pass
@op_selector
class SoftmaxCrossEntropyWithLogits:
def __call__(self, *args):
pass
@op_selector
class LogSoftmax:
def __call__(self, *args):
pass
@op_selector
class Tanh:
def __call__(self, *args):
pass
@op_selector
class GeLU:
def __call__(self, *args):
pass
@op_selector
class FastGeLU:
def __call__(self, *args):
pass
@op_selector
class LayerNorm:
def __call__(self, *args):
pass
@op_selector
class Softmax:
def __call__(self, *args):
pass
@op_selector
class LambUpdateWithLR:
def __call__(self, *args):
pass
@op_selector
class LambNextMV:
def __call__(self, *args):
pass
@op_selector
class LambApplyOptimizerAssign:
def __call__(self, *args):
pass
@op_selector
class LambApplyWeightAssign:
def __call__(self, *args):
pass


@ -20,6 +20,7 @@ which can be used to control the switch of op type: GraphKernel or Primitive.
import importlib import importlib
import inspect import inspect
from mindspore import context from mindspore import context
from mindspore.common._decorator import deprecated
class _OpSelector: class _OpSelector:
@ -70,6 +71,7 @@ class _OpSelector:
return op(*args, **kwargs) return op(*args, **kwargs)
@deprecated("1.3", "basic Primitive", False)
def new_ops_selector(primitive_pkg, graph_kernel_pkg): def new_ops_selector(primitive_pkg, graph_kernel_pkg):
""" """
A factory method to return an op selector A factory method to return an op selector
@ -83,6 +85,8 @@ def new_ops_selector(primitive_pkg, graph_kernel_pkg):
The order of the highest priority to lowest priority is (1), (2), (3) The order of the highest priority to lowest priority is (1), (2), (3)
If the GraphKernel switch is off, then op_type will always be PRIMITIVE. If the GraphKernel switch is off, then op_type will always be PRIMITIVE.
The user-defined GraphKernel Cell is deprecated; this interface will be removed in a future version.
Args: Args:
primitive_pkg (str): primitive op's package name primitive_pkg (str): primitive op's package name
graph_kernel_pkg (str): graph kernel op's package name graph_kernel_pkg (str): graph kernel op's package name
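As a hedged illustration of the deprecation note above, code that went through the selector can simply construct the real operation; the axis value below is an arbitrary example.

from mindspore.ops import operations as P

softmax = P.Softmax(axis=-1)           # recommended: use the real operation directly
# softmax = _selected_ops.Softmax(-1)  # deprecated selector path removed by this change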


@ -16,6 +16,7 @@
"""Other operators.""" """Other operators."""
import functools import functools
from mindspore.common import monad from mindspore.common import monad
from mindspore.common._decorator import deprecated
from .. import signature as sig from .. import signature as sig
from ..._checkparam import Validator as validator, Rel from ..._checkparam import Validator as validator, Rel
from ...common import dtype as mstype from ...common import dtype as mstype
@ -84,6 +85,8 @@ class InplaceAssign(PrimitiveWithInfer):
Inplace assign `Parameter` with a value. Inplace assign `Parameter` with a value.
This primitive can only be used in graph kernel.
InplaceAssign is deprecated since version 1.3 and will be removed in a future version; use Assign instead.
Inputs: Inputs:
- **variable** (Parameter) - The `Parameter`. - **variable** (Parameter) - The `Parameter`.
- **value** (Tensor) - The value to be assigned. - **value** (Tensor) - The value to be assigned.
@ -110,7 +113,8 @@ class InplaceAssign(PrimitiveWithInfer):
>>> net = Net() >>> net = Net()
>>> output = net(x) >>> output = net(x)
>>> print(output) >>> print(output)
""" """
@deprecated("1.3", "Assign", False)
@prim_attr_register
def __init__(self): def __init__(self):
self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])
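A minimal sketch of replacing the deprecated InplaceAssign with Assign inside a cell; the Net and parameter names are illustrative assumptions. Assign runs in both graph and PyNative modes, so the graph-kernel-only restriction noted above no longer applies.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Parameter, Tensor
from mindspore.ops import operations as P

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.var = Parameter(Tensor(np.zeros((1,)), ms.float32), name="var")
        self.assign = P.Assign()  # replaces the deprecated InplaceAssign

    def construct(self, value):
        # Assign writes `value` into the parameter and returns the updated parameter.
        return self.assign(self.var, value)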


@ -26,7 +26,6 @@ from mindspore.nn.loss.loss import _Loss
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import _selected_ops
from mindspore.parallel._utils import _reset_op_id from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model from mindspore.train import Model
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
@ -76,7 +75,7 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction) super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
self.sparse = sparse self.sparse = sparse
self.reduction = reduction self.reduction = reduction
self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits() self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
self.one_hot = P.OneHot() self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32) self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0., mstype.float32) self.off_value = Tensor(0., mstype.float32)