forked from mindspore-Ecosystem/mindspore
!15172 Clean GraphKernel's codes from frontend
From: @dayschan Reviewed-by: @gaoxiong1,@dylangeng,@gaoxiong1 Signed-off-by: @dylangeng,@dylangeng
This commit is contained in:
commit
0fd1726e79
|
@ -17,6 +17,7 @@
|
|||
#include <algorithm>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
|
||||
|
@ -110,8 +111,6 @@
|
|||
#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/split_fission.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/splitv_fission.h"
|
||||
#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
|
||||
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
|
||||
#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
|
||||
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
|
||||
|
@ -199,19 +198,6 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
|
||||
}
|
||||
} // namespace
|
||||
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
MS_EXCEPTION_IF_NULL(optimizer);
|
||||
auto common_process = std::make_shared<PassManager>("graph_kernel_common_process");
|
||||
MS_EXCEPTION_IF_NULL(common_process);
|
||||
common_process->AddPass(std::make_shared<ModifyOpAttrs>());
|
||||
common_process->AddPass(std::make_shared<RemoveNoUseReshapeOp>());
|
||||
optimizer->AddPassManager(common_process);
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||
|
|
|
@ -24,7 +24,6 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
|||
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||
} // namespace opt
|
||||
|
|
|
@ -344,7 +344,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
|
|||
|
||||
auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
|
||||
auto real_input_node = kernel_with_index.first;
|
||||
if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
if (kernel::IsWeightBoundary(real_input_node)) {
|
||||
// weight
|
||||
origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
|
||||
if (origin_type == kTypeUnknown) {
|
||||
|
@ -358,9 +358,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
|
|||
const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
|
||||
// In graph kernel, we check parameter,
|
||||
// the eliminate pass will not eliminate this case, so we just do not insert the no used cast.
|
||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
|
||||
new_inputs.push_back(cur_input);
|
||||
} else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
|
||||
if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
|
||||
auto cast =
|
||||
AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
|
||||
MS_EXCEPTION_IF_NULL(cast);
|
||||
|
|
|
@ -78,21 +78,12 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> todos = {node};
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
kernel::GetValidKernelNodes(sub_graph, &todos);
|
||||
}
|
||||
|
||||
for (auto &t : todos) {
|
||||
CNodePtr cnode = t->cast<CNodePtr>();
|
||||
size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t i = 0; i < in_num; ++i) {
|
||||
if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
|
||||
MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
|
||||
<< cnode->DebugString() << "]";
|
||||
}
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t i = 0; i < in_num; ++i) {
|
||||
if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
|
||||
MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
|
||||
<< cnode->DebugString() << "]";
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
|
|
|
@ -30,8 +30,7 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<bool> &need_insert_cast) {
|
||||
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
|
@ -42,32 +41,23 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
|
|||
for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
|
||||
AnfNodePtr replace_node = nullptr;
|
||||
const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
|
||||
const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
|
||||
const auto origin_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
|
||||
auto idx = NewValueNode(SizeToLong(output_idx));
|
||||
MS_EXCEPTION_IF_NULL(idx);
|
||||
auto imm = std::make_shared<Int64Imm>(output_idx);
|
||||
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
|
||||
auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
|
||||
AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get());
|
||||
if (need_insert_cast[output_idx]) {
|
||||
const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
|
||||
TypeId origin_type(kTypeUnknown);
|
||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
|
||||
}
|
||||
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
|
||||
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
|
||||
if (origin_type != device_type) {
|
||||
replace_node = AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape,
|
||||
infer_type, AnfAlgo::GetOutputReshapeType(getitem, 0));
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
replace_node->set_scope(cnode->scope());
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
|
||||
kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
|
||||
}
|
||||
} else {
|
||||
replace_node = getitem;
|
||||
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, getitem.get());
|
||||
const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
|
||||
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
|
||||
if (origin_type != device_type) {
|
||||
replace_node = AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape,
|
||||
origin_type, AnfAlgo::GetOutputReshapeType(getitem, 0));
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
replace_node->set_scope(cnode->scope());
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
|
||||
kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
|
||||
}
|
||||
} else {
|
||||
replace_node = getitem;
|
||||
|
@ -81,8 +71,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
|
|||
return make_tuple;
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::vector<bool> &need_insert_cast) {
|
||||
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
|
||||
|
@ -92,23 +81,14 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
|
|||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
// Single output
|
||||
if (!cnode->Type()->isa<Tuple>()) {
|
||||
if (!need_insert_cast[0]) {
|
||||
return cnode;
|
||||
}
|
||||
|
||||
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
|
||||
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
||||
const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
|
||||
TypeId origin_type(kTypeUnknown);
|
||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
|
||||
}
|
||||
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
|
||||
const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
|
||||
const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
|
||||
AnfNodePtr replace_node = cnode;
|
||||
if (origin_type != device_type) {
|
||||
replace_node = AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape,
|
||||
infer_type, AnfAlgo::GetOutputReshapeType(cnode, 0));
|
||||
origin_type, AnfAlgo::GetOutputReshapeType(cnode, 0));
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
replace_node->set_scope(cnode->scope());
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
||||
|
@ -119,69 +99,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
|
|||
return replace_node;
|
||||
}
|
||||
// Multiple output
|
||||
return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast);
|
||||
}
|
||||
|
||||
AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
// insert cast for ops in graph kernel.
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
auto mng = sub_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
std::vector<AnfNodePtr> todo;
|
||||
kernel::GetValidKernelNodes(sub_graph, &todo);
|
||||
auto outputs = AnfAlgo::GetAllOutput(sub_graph->output(), {prim::kPrimTupleGetItem});
|
||||
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
|
||||
for (auto &output : outputs) {
|
||||
size_t index = 0;
|
||||
if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
|
||||
ValuePtr tuple_index_value = GetValueNode(output->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
|
||||
MS_EXCEPTION_IF_NULL(tuple_index_value);
|
||||
if (!tuple_index_value->isa<Int64Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64";
|
||||
}
|
||||
index = tuple_index_value->cast<Int64ImmPtr>()->value();
|
||||
}
|
||||
graph_rets.emplace_back(std::pair<AnfNodePtr, size_t>(output, index));
|
||||
}
|
||||
for (auto &t : todo) {
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t);
|
||||
// process input
|
||||
CNodePtr t_cnode = t->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(t_cnode);
|
||||
auto t_new_node = InsertCastForInput(sub_graph, t_cnode);
|
||||
AnfNodePtr t_new_node_1 = nullptr;
|
||||
std::vector<bool> need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true);
|
||||
// process output
|
||||
auto iter = std::find_if(graph_rets.begin(), graph_rets.end(),
|
||||
[&t](const std::pair<AnfNodePtr, size_t> &ret) { return ret.first == t; });
|
||||
if (iter != graph_rets.end()) {
|
||||
auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t);
|
||||
auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second);
|
||||
auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin());
|
||||
if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) {
|
||||
need_insert_cast[iter->second] = false;
|
||||
} else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) {
|
||||
need_insert_cast[iter->second] = false;
|
||||
}
|
||||
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
|
||||
} else {
|
||||
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
|
||||
}
|
||||
|
||||
if (t_new_node_1 != nullptr && t_new_node_1 != t) {
|
||||
(void)mng->Replace(t, t_new_node_1);
|
||||
}
|
||||
}
|
||||
|
||||
// insert cast for graph kernel.
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||
// process input
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto new_node = InsertCastForInput(func_graph, cnode);
|
||||
// process output
|
||||
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
|
||||
return InsertCastForMultipleOutput(func_graph, cnode);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -196,11 +114,6 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
|
|||
if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
return ProcessGraphKernelOp(func_graph, node);
|
||||
}
|
||||
// insert cast for single op.
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||
// process input
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
|
@ -211,7 +124,7 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
|
|||
kernel_graph->ReplaceInternalOutput(node, new_node);
|
||||
}
|
||||
// process output
|
||||
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
|
||||
return InsertCastForOutput(func_graph, new_node);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -150,9 +150,6 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
|
|||
return nullptr;
|
||||
}
|
||||
auto next_cnode = next_node->cast<CNodePtr>();
|
||||
if (AnfAlgo::IsGraphKernel(next_node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto next_op_name = AnfAlgo::GetCNodeName(next_cnode);
|
||||
if (next_op_name == prim::kPrimSend->name() || next_op_name == kStackPushOpName) {
|
||||
return nullptr;
|
||||
|
@ -224,9 +221,6 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
|
|||
return nullptr;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(prior_op);
|
||||
if (AnfAlgo::IsGraphKernel(prior_op)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||
if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name() ||
|
||||
|
|
|
@ -1,99 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "base/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) {
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
|
||||
auto input_format = AnfAlgo::GetInputFormat(cnode, 0);
|
||||
if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode);
|
||||
return cnode;
|
||||
}
|
||||
|
||||
AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) {
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
|
||||
if (input_shape.size() != 5) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto multiples = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrMultiples);
|
||||
if (multiples.size() == 4 && multiples[1] == 1) {
|
||||
multiples.push_back(1);
|
||||
AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode);
|
||||
}
|
||||
|
||||
return cnode;
|
||||
}
|
||||
|
||||
AnfNodePtr ModifyAttrs(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (op_name == prim::kPrimTile->name()) {
|
||||
return ModifyTileOpAttrs(cnode);
|
||||
} else if (op_name == prim::kPrimReduceSum->name()) {
|
||||
// kPrimReduceMean
|
||||
// kPrimReduceSum
|
||||
// kPrimReduceAll
|
||||
// kPrimReduceMax
|
||||
// kPrimReduceMin
|
||||
return ModifyReduceOpsAttrs(cnode);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node);
|
||||
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto manager = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<AnfNodePtr> todos;
|
||||
kernel::GetValidKernelNodes(fg, &todos);
|
||||
for (auto &t : todos) {
|
||||
auto new_node = ModifyAttrs(t->cast<CNodePtr>());
|
||||
if (new_node != nullptr && new_node != t) {
|
||||
(void)manager->Replace(t, new_node);
|
||||
}
|
||||
}
|
||||
return node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -1,33 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ModifyOpAttrs : public PatternProcessPass {
|
||||
public:
|
||||
explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {}
|
||||
~ModifyOpAttrs() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
|
|
@ -1,66 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "base/core_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
||||
if (op_name != prim::kPrimReshape->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
|
||||
auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
|
||||
if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return cnode->input(1);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(DEBUG) << "====process op: " << AnfAlgo::GetCNodeName(node);
|
||||
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto manager = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<AnfNodePtr> todos;
|
||||
kernel::GetValidKernelNodes(fg, &todos);
|
||||
for (auto &t : todos) {
|
||||
auto new_node = RemoveReshapeOp(t->cast<CNodePtr>());
|
||||
if (new_node != nullptr && new_node != t) {
|
||||
(void)manager->Replace(t, new_node);
|
||||
}
|
||||
}
|
||||
return node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -1,33 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class RemoveNoUseReshapeOp : public PatternProcessPass {
|
||||
public:
|
||||
explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {}
|
||||
~RemoveNoUseReshapeOp() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
|
|
@ -122,9 +122,6 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f
|
|||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<CNodePtr> cast_nodes;
|
||||
|
|
|
@ -596,20 +596,21 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
|
|||
std::vector<PrimitivePtr> GetFusibleOpList() {
|
||||
#if ENABLE_D
|
||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
||||
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::KPrimTransData};
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
||||
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimInplaceAssign,
|
||||
prim::KPrimTransData};
|
||||
#elif ENABLE_GPU
|
||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
||||
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
|
||||
prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||
prim::kPrimAssign, prim::kPrimLessEqual, prim::kPrimGreaterEqual, prim::kPrimReduceMax, prim::kPrimReduceMin,
|
||||
prim::kPrimLess};
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
||||
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
|
||||
prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||
prim::kPrimAssign, prim::kPrimLessEqual, prim::kPrimGreaterEqual, prim::kPrimReduceMax, prim::kPrimReduceMin,
|
||||
prim::kPrimLess, prim::kPrimInplaceAssign};
|
||||
#else
|
||||
std::vector<PrimitivePtr> fusible_basic_ops;
|
||||
#endif
|
||||
|
|
|
@ -81,13 +81,9 @@ const AnfNodePtr AddTrainingAttr::Process(const FuncGraphPtr &func_graph, const
|
|||
if (iter == MarkOp.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
return nullptr;
|
||||
} else {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
AddAttrTraining(func_graph, cnode);
|
||||
return cnode;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
AddAttrTraining(func_graph, cnode);
|
||||
return cnode;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
#include "backend/optimizer/pass/convert_attr_to_unify_mindir.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
@ -29,32 +28,22 @@ const AnfNodePtr ConvertAttrToUnifyMindIR::Process(const FuncGraphPtr &, const A
|
|||
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> todos;
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
kernel::GetValidKernelNodes(sub_graph, &todos);
|
||||
} else {
|
||||
todos.push_back(node);
|
||||
}
|
||||
|
||||
for (auto &t : todos) {
|
||||
CNodePtr cnode = t->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
AnfNodePtr op = inputs[0];
|
||||
if (IsValueNode<Primitive>(op)) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(op);
|
||||
auto attrs = prim->attrs();
|
||||
std::string type_name = prim->name();
|
||||
for (auto attr : attrs) {
|
||||
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr.second);
|
||||
if (converted) {
|
||||
prim->set_attr(attr.first, attr.second);
|
||||
}
|
||||
bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(type_name, attr.first, &attr.second);
|
||||
if (converted_ir_attr) {
|
||||
prim->set_attr(attr.first, attr.second);
|
||||
}
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
AnfNodePtr op = inputs[0];
|
||||
if (IsValueNode<Primitive>(op)) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(op);
|
||||
auto attrs = prim->attrs();
|
||||
std::string type_name = prim->name();
|
||||
for (auto attr : attrs) {
|
||||
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr.second);
|
||||
if (converted) {
|
||||
prim->set_attr(attr.first, attr.second);
|
||||
}
|
||||
bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(type_name, attr.first, &attr.second);
|
||||
if (converted_ir_attr) {
|
||||
prim->set_attr(attr.first, attr.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
#include "backend/optimizer/pass/convert_const_input_to_attr.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
|
@ -34,40 +33,31 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
|
|||
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> todos;
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
kernel::GetValidKernelNodes(sub_graph, &todos);
|
||||
} else {
|
||||
todos.push_back(node);
|
||||
}
|
||||
|
||||
for (auto &t : todos) {
|
||||
CNodePtr cnode = t->cast<CNodePtr>();
|
||||
ConstInputToAttrInfoRegister reg;
|
||||
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), ®)) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
|
||||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::IsDynamicShape(cnode) &&
|
||||
DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
|
||||
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
ConstInputToAttrInfoRegister reg;
|
||||
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), ®)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
|
||||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
if (AnfAlgo::IsDynamicShape(cnode) &&
|
||||
DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
|
||||
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
|
||||
return nullptr;
|
||||
}
|
||||
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
|
||||
|
||||
return node;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -100,24 +100,6 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
auto mng = sub_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
std::vector<AnfNodePtr> todo;
|
||||
kernel::GetValidKernelNodes(sub_graph, &todo);
|
||||
|
||||
for (auto &t : todo) {
|
||||
auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast<CNodePtr>());
|
||||
if (t_new_node != nullptr && t_new_node != t) {
|
||||
(void)mng->Replace(t, t_new_node);
|
||||
}
|
||||
}
|
||||
|
||||
return node;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
|
@ -129,11 +111,8 @@ const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &fun
|
|||
if (!node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
return ProcessGraphKernelOp(node);
|
||||
} else {
|
||||
return ConstInputToTensorInput(func_graph, node->cast<CNodePtr>());
|
||||
}
|
||||
|
||||
return ConstInputToTensorInput(func_graph, node->cast<CNodePtr>());
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -95,15 +95,6 @@ const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &fu
|
|||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
std::vector<AnfNodePtr> todos;
|
||||
kernel::GetValidKernelNodes(sub_graph, &todos);
|
||||
for (auto &t : todos) {
|
||||
ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
|
||||
return node;
|
||||
}
|
||||
|
|
|
@ -170,26 +170,6 @@ const AnfNodePtr EliminateRedundantOp::Process(const FuncGraphPtr &func_graph, c
|
|||
if (cnode == nullptr || func_graph == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
// do eliminate for ops in graph kernel.
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
auto mng = sub_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
std::vector<AnfNodePtr> todo;
|
||||
kernel::GetValidKernelNodes(sub_graph, &todo);
|
||||
for (auto &t : todo) {
|
||||
CNodePtr t_cnode = t->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(t_cnode);
|
||||
auto t_new_node = DoEliminate(sub_graph, t_cnode);
|
||||
if (t_new_node != nullptr && t_new_node != t) {
|
||||
(void)mng->Replace(t, t_new_node);
|
||||
}
|
||||
}
|
||||
return node;
|
||||
}
|
||||
// do eliminate for single op.
|
||||
return DoEliminate(func_graph, cnode);
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -30,20 +30,7 @@ const BaseRef EraseVisitAttr::DefinePattern() const {
|
|||
}
|
||||
|
||||
const AnfNodePtr EraseVisitAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
if (node != nullptr && AnfAlgo::IsRealCNodeKernel(node)) {
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
std::vector<AnfNodePtr> todos;
|
||||
kernel::GetValidKernelNodes(fg, &todos);
|
||||
for (auto &t : todos) {
|
||||
AnfAlgo::EraseNodeAttr(kAttrVisited, t);
|
||||
}
|
||||
}
|
||||
AnfAlgo::EraseNodeAttr(kAttrVisited, node);
|
||||
} else {
|
||||
AnfAlgo::EraseNodeAttr(kAttrVisited, node);
|
||||
}
|
||||
AnfAlgo::EraseNodeAttr(kAttrVisited, node);
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -949,7 +949,6 @@ void AscendSession::InitRuntimeResource() {
|
|||
void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "HardwareOptimize start!";
|
||||
opt::AscendBackendOptimization(kernel_graph);
|
||||
opt::AscendGraphKernelCommonProcess(kernel_graph);
|
||||
GraphKernelOptimize(kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
kernel_graph->SetExecOrderByDefault();
|
||||
|
|
|
@ -433,7 +433,9 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
|||
auto cnode = FuncGraph::NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
|
||||
CreateKernelInfoFromNewParameter(cnode);
|
||||
if (AnfAlgo::IsGraphKernel(cnode)) {
|
||||
CreateKernelInfoFromNewParameter(cnode);
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
|
||||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
|
||||
}
|
||||
|
@ -443,9 +445,6 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
|||
}
|
||||
|
||||
void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
|
||||
if (!AnfAlgo::IsGraphKernel(cnode)) {
|
||||
return;
|
||||
}
|
||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
|
|
|
@ -45,10 +45,6 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
|
|||
TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
|
||||
k_graph_ = std::make_shared<FuncGraph>();
|
||||
}
|
||||
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||
k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
|
||||
}
|
||||
// To keep switch_layer's inputs from being inlined
|
||||
k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
|
||||
k_graph_->set_stage(primal_graph->stage());
|
||||
|
@ -58,11 +54,6 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
|
|||
tape_ = std::make_shared<FuncGraph>();
|
||||
}
|
||||
tape_->set_stage(primal_graph->stage());
|
||||
// Add "_Grad" postfix
|
||||
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad";
|
||||
tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
|
||||
}
|
||||
|
||||
dout_ = tape_->add_parameter();
|
||||
}
|
||||
|
|
|
@ -194,8 +194,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
// Incorporation
|
||||
incorporate_getitem_set_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
|
||||
incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(),
|
||||
"incorporate_getitem_from_param", IsCNodeGraphKernel);
|
||||
incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
|
||||
incorporate_call_switch_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
|
||||
|
@ -211,19 +209,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
print_tuple_wrapper_ =
|
||||
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
|
||||
|
||||
// Unused parameter eliminate
|
||||
unused_parameter_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel);
|
||||
unused_output_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
|
||||
|
||||
// tuple parameter graph transform
|
||||
call_graph_tuple_transform_ =
|
||||
MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transorm", IsCNode);
|
||||
|
||||
// AddN eliminate
|
||||
addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
|
||||
|
||||
// RowTensor Eliminate
|
||||
row_tensor_eliminate_ = MakeSubstitution(
|
||||
std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",
|
||||
|
|
|
@ -112,7 +112,6 @@ class OptimizeIRPassLib {
|
|||
|
||||
// Incorporation
|
||||
SubstitutionPtr incorporate_getitem_set_;
|
||||
SubstitutionPtr incorporate_getitem_from_param_;
|
||||
SubstitutionPtr incorporate_call_;
|
||||
SubstitutionPtr incorporate_call_switch_;
|
||||
|
||||
|
@ -125,16 +124,9 @@ class OptimizeIRPassLib {
|
|||
// Convert
|
||||
SubstitutionPtr print_tuple_wrapper_;
|
||||
|
||||
// Unused parameter eliminate
|
||||
SubstitutionPtr unused_parameter_eliminate_;
|
||||
SubstitutionPtr unused_output_eliminate_;
|
||||
|
||||
// tuple parameter graph transform
|
||||
SubstitutionPtr call_graph_tuple_transform_;
|
||||
|
||||
// AddN eliminate
|
||||
SubstitutionPtr addn_eliminate_;
|
||||
|
||||
// RowTensor Eliminate
|
||||
SubstitutionPtr row_tensor_eliminate_;
|
||||
|
||||
|
@ -214,23 +206,6 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) {
|
|||
return IsValueNode<FuncGraph>(inp0);
|
||||
}
|
||||
|
||||
// Check if CNode Input 0 is Func Graph of graph kernel.
|
||||
inline bool IsCNodeGraphKernel(const AnfNodePtr &node) {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto inp0 = node->cast<CNodePtr>()->input(0);
|
||||
if (IsValueNode<FuncGraph>(inp0)) {
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inp0);
|
||||
if (fg == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if CNode Input 0 is CNode
|
||||
inline bool IsCNodeDup(const AnfNodePtr &node) {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
|
@ -135,25 +134,6 @@ class IncorporateGetitem : public AnfVisitor {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
// If graph kernel has muti output, do not split.
|
||||
// some graph kernel output has EnvInstance node or DeadCode node should split.
|
||||
auto output = fg_->output();
|
||||
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
|
||||
auto output_cnode = output->cast<CNodePtr>();
|
||||
auto outputs = output_cnode->inputs();
|
||||
int64_t real_output_cnt = 0;
|
||||
for (size_t i = 1; i < outputs.size(); ++i) {
|
||||
if (IsCNode(outputs[i]) || IsValueNode<tensor::Tensor>(outputs[i]) || IsParam(outputs[i])) {
|
||||
real_output_cnt++;
|
||||
if (real_output_cnt > 1) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto new_fg = getitem_transform_(fg_, idx_);
|
||||
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
|
||||
return node->func_graph()->NewCNode(args_);
|
||||
|
@ -184,171 +164,6 @@ class IncorporateGetitem : public AnfVisitor {
|
|||
internal::GetitemTransform getitem_transform_;
|
||||
};
|
||||
|
||||
class IncorporateGetitemFromParam : public AnfVisitor {
|
||||
public:
|
||||
void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr ¶m, size_t input_idx) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto &node_users = mng->node_users();
|
||||
if (node_users.find(param) == node_users.end() || node_users[param].empty()) {
|
||||
args_.push_back(cnode->input(input_idx + 1));
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto &user : node_users[param]) {
|
||||
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
|
||||
// we do not process this case.
|
||||
args_.push_back(cnode->input(input_idx + 1));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// update new args.
|
||||
if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) {
|
||||
// case 1
|
||||
replace_parameters_[input_idx] = true;
|
||||
need_update_ = true;
|
||||
auto make_tuple_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
|
||||
auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs();
|
||||
inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1;
|
||||
args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end());
|
||||
} else {
|
||||
// case 2
|
||||
auto prev_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
|
||||
auto prev_fg = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
|
||||
auto fg_output = prev_fg->output();
|
||||
if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(ERROR) << "The return of: " << prev_fg->ToString()
|
||||
<< " should be a make tuple, but got: " << fg_output->DebugString();
|
||||
return;
|
||||
}
|
||||
replace_parameters_[input_idx] = true;
|
||||
need_update_ = true;
|
||||
auto make_tuple_cnode = fg_output->cast<CNodePtr>();
|
||||
inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1;
|
||||
for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) {
|
||||
auto new_getitem =
|
||||
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToLong(output_i))});
|
||||
auto aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(SizeToLong(output_i)));
|
||||
new_getitem->input(2)->set_abstract(aptr);
|
||||
new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract());
|
||||
args_.push_back(new_getitem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Reset();
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto &inputs = cnode->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
if (fg == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto parameters = fg->parameters();
|
||||
if (parameters.size() != inputs.size() - 1) {
|
||||
return nullptr;
|
||||
}
|
||||
replace_parameters_ = std::vector<bool>(parameters.size(), false);
|
||||
inputs_num_ = std::vector<size_t>(parameters.size(), 1);
|
||||
auto node_fg = node->func_graph();
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) {
|
||||
Process(node_fg, cnode, parameters[i - 1], i - 1);
|
||||
} else {
|
||||
args_.push_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (!need_update_) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
|
||||
mng->AddFuncGraph(new_fg);
|
||||
|
||||
auto node_users = mng->node_users();
|
||||
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
|
||||
std::vector<AnfNodePtr> new_parameters;
|
||||
size_t curr_input_idx{0};
|
||||
for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) {
|
||||
if (!replace_parameters_[param_i]) {
|
||||
if (parameters[param_i]->abstract() != nullptr) {
|
||||
new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract());
|
||||
}
|
||||
new_parameters.push_back(new_fg_parameters[param_i]);
|
||||
curr_input_idx++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// make a new parameter.
|
||||
for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) {
|
||||
auto new_param = std::make_shared<Parameter>(new_fg);
|
||||
new_param->set_abstract(args_.at(curr_input_idx)->abstract());
|
||||
|
||||
// update users of new parameter.
|
||||
for (auto &user : node_users[new_fg_parameters[param_i]]) {
|
||||
idx_ = -1;
|
||||
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode<Int64Imm>})(user.first);
|
||||
if (idx_ == -1) {
|
||||
MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString()
|
||||
<< " must be tuple getitem here, but got: " << user.first->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (input_i == LongToSize(idx_)) {
|
||||
for (auto &sub_user : node_users[user.first]) {
|
||||
auto sub_user_cnode = sub_user.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sub_user_cnode);
|
||||
sub_user_cnode->set_input(sub_user.second, new_param);
|
||||
(void)mng->Replace(sub_user.first, sub_user_cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
new_parameters.push_back(new_param);
|
||||
curr_input_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
mng->SetParameters(new_fg, new_parameters);
|
||||
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
|
||||
auto new_call = node_fg->NewCNode(args_);
|
||||
new_call->set_abstract(node->abstract());
|
||||
return new_call;
|
||||
}
|
||||
|
||||
void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int64_t>(vnode->value()); }
|
||||
|
||||
void Visit(const CNodePtr &cnode) override {}
|
||||
|
||||
void Reset() {
|
||||
replace_parameters_.clear();
|
||||
args_.clear();
|
||||
inputs_num_.clear();
|
||||
need_update_ = false;
|
||||
idx_ = -1;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<bool> replace_parameters_{};
|
||||
std::vector<AnfNodePtr> args_{};
|
||||
std::vector<size_t> inputs_num_{};
|
||||
bool need_update_{false};
|
||||
int64_t idx_{-1};
|
||||
};
|
||||
|
||||
// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C}
|
||||
class IncorporateGetitemSwitch : public AnfVisitor {
|
||||
public:
|
||||
|
|
|
@ -103,13 +103,6 @@ class InlinerBase : public AnfVisitor {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// Do not inline GraphKernel to Cell.
|
||||
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
// If the GraphKernel only contains a return node, we make it inlined.
|
||||
if (fg->nodes().size() - fg->parameters().size() > 1) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
Reset();
|
||||
|
||||
// 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}
|
||||
|
|
|
@ -205,130 +205,6 @@ class AddNZeroFilter : public AnfVisitor {
|
|||
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
|
||||
bool has_zero_like_{false};
|
||||
};
|
||||
|
||||
// {PrimAddN, {kPrimMakeTuple, Xs}}
|
||||
// Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd.
|
||||
// case0: AddN(inputs)(inputs size < 2) -> error
|
||||
// case1: AddN(inputs)(all inputs is ValueNode) -> error
|
||||
// case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor)
|
||||
// case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input)
|
||||
// -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
|
||||
class AddNEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
if (fg->recursive()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
|
||||
mng->AddFuncGraph(new_fg);
|
||||
need_update_ = false;
|
||||
bool changed;
|
||||
do {
|
||||
changed = Process(new_fg);
|
||||
} while (changed);
|
||||
|
||||
if (!need_update_) {
|
||||
return nullptr;
|
||||
} else {
|
||||
auto new_sx = inputs;
|
||||
new_sx[0] = NewValueNode(new_fg);
|
||||
return node->func_graph()->NewCNode(new_sx);
|
||||
}
|
||||
}
|
||||
|
||||
bool Process(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto nodes = TopoSort(func_graph->output());
|
||||
bool changed = false;
|
||||
|
||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
||||
auto node = nodes[i];
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &tuple_input = cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(tuple_input);
|
||||
auto tuple_input_cnode = tuple_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_input_cnode);
|
||||
auto &tuple_inputs = tuple_input_cnode->inputs();
|
||||
if (tuple_inputs.size() < 3) {
|
||||
// case0: inputs size < 2, error
|
||||
MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2);
|
||||
}
|
||||
|
||||
int64_t valuenode_num = std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0,
|
||||
[](int64_t accumulator, const AnfNodePtr &node) {
|
||||
if (IsValueNode<tensor::Tensor>(node)) {
|
||||
return accumulator + 1;
|
||||
} else {
|
||||
return accumulator;
|
||||
}
|
||||
});
|
||||
if (LongToSize(valuenode_num) == tuple_inputs.size()) {
|
||||
// case1: all inputs is ValueNode, error
|
||||
MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2);
|
||||
}
|
||||
|
||||
if (tuple_inputs.size() == 3) {
|
||||
// case2: inputs size = 2, -> TensorAdd(Tensor, Tensor)
|
||||
MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2);
|
||||
ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
|
||||
std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1],
|
||||
tuple_inputs[2]};
|
||||
mng->Replace(node, func_graph->NewCNode(new_xs));
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(),
|
||||
[](const AnfNodePtr &node) { return IsValueNode<tensor::Tensor>(node); });
|
||||
if (first_valuenode == tuple_inputs.end()) {
|
||||
// no ValueNode input found.
|
||||
continue;
|
||||
} else {
|
||||
// case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
|
||||
std::vector<AnfNodePtr> make_tuple_new_xs{
|
||||
NewValueNode(prim::kPrimMakeTuple),
|
||||
};
|
||||
std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(),
|
||||
[&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) {
|
||||
if (node != *first_valuenode) {
|
||||
make_tuple_new_xs.push_back(node);
|
||||
}
|
||||
});
|
||||
ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations");
|
||||
auto new_addn = func_graph->NewCNode(
|
||||
{func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)});
|
||||
ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
|
||||
auto new_add =
|
||||
func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn});
|
||||
(void)mng->Replace(node, new_add);
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
need_update_ = need_update_ || changed;
|
||||
return changed;
|
||||
}
|
||||
|
||||
private:
|
||||
bool need_update_{false};
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <tuple>
|
||||
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
|
@ -162,146 +161,6 @@ class SpecializeOnGraphArguments : public AnfVisitor {
|
|||
private:
|
||||
internal::SpecializeTransform specialize_transform_;
|
||||
};
|
||||
|
||||
// Eliminate unused parameters.
|
||||
// {G, Xs}
|
||||
class UnusedParasEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto &inputs = cnode->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
|
||||
std::vector<AnfNodePtr> parameters = fg->parameters();
|
||||
size_t size = parameters.size();
|
||||
if (size != inputs.size() - 1) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> new_xs;
|
||||
std::vector<bool> keep_parameters;
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto &node_users = mng->node_users();
|
||||
bool has_unused_para = false;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto iter = node_users.find(parameters[i]);
|
||||
if (iter != node_users.end() && !iter->second.empty()) {
|
||||
keep_parameters.push_back(true);
|
||||
new_xs.push_back(inputs[i + 1]);
|
||||
continue;
|
||||
}
|
||||
keep_parameters.push_back(false);
|
||||
has_unused_para = true;
|
||||
}
|
||||
|
||||
if (!has_unused_para) {
|
||||
return nullptr;
|
||||
}
|
||||
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
|
||||
mng->AddFuncGraph(new_fg);
|
||||
|
||||
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
|
||||
std::vector<AnfNodePtr> new_parameters;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
if (keep_parameters[i]) {
|
||||
if (parameters[i]->abstract() != nullptr) {
|
||||
new_fg_parameters[i]->set_abstract(parameters[i]->abstract());
|
||||
}
|
||||
new_parameters.push_back(new_fg_parameters[i]);
|
||||
}
|
||||
}
|
||||
mng->SetParameters(new_fg, new_parameters);
|
||||
|
||||
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
|
||||
return node->func_graph()->NewCNode(new_xs);
|
||||
}
|
||||
};
|
||||
|
||||
// Eliminate unused outputs.
|
||||
// {G, Xs}
|
||||
class UnusedOutputEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
if (fg->recursive()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
|
||||
mng->AddFuncGraph(new_fg);
|
||||
auto new_fg_output = new_fg->output();
|
||||
if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto output_cnode = new_fg_output->cast<CNodePtr>();
|
||||
auto &node_users = mng->node_users();
|
||||
if (node_users.count(node) == 0 || node_users[node].empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
std::unordered_set<int64_t> used_output_idx;
|
||||
std::vector<std::pair<AnfNodePtr, int64_t>> all_users;
|
||||
for (auto &node_user : node_users[node]) {
|
||||
if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto user_cnode = node_user.first->cast<CNodePtr>();
|
||||
size_t used_idx = GetValue<int64_t>(user_cnode->input(2)->cast<ValueNodePtr>()->value());
|
||||
used_output_idx.insert(used_idx);
|
||||
all_users.push_back(std::make_pair(node_user.first, used_idx));
|
||||
}
|
||||
|
||||
if (used_output_idx.size() >= output_cnode->inputs().size() - 1) {
|
||||
// all output has users.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (used_output_idx.empty()) {
|
||||
// we do not process this case.
|
||||
return nullptr;
|
||||
} else if (used_output_idx.size() == 1) {
|
||||
// after eliminate, only one output left.
|
||||
new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1));
|
||||
// update users.
|
||||
for (auto &ret_user : all_users) {
|
||||
(void)mng->Replace(ret_user.first, node);
|
||||
}
|
||||
} else {
|
||||
// after eliminate, create new multi output.
|
||||
std::vector<AnfNodePtr> new_output_inputs{output_cnode->input(0)};
|
||||
std::unordered_map<int64_t, int64_t> new_idx_map;
|
||||
for (auto idx : used_output_idx) {
|
||||
new_idx_map[idx] = SizeToLong(new_output_inputs.size() - 1);
|
||||
new_output_inputs.push_back(output_cnode->input(idx + 1));
|
||||
}
|
||||
new_fg->set_output(new_fg->NewCNode(new_output_inputs));
|
||||
// update users.
|
||||
for (auto &ret_user : all_users) {
|
||||
auto ret_user_cnode = ret_user.first->cast<CNodePtr>();
|
||||
ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second]));
|
||||
}
|
||||
}
|
||||
|
||||
auto new_sx = inputs;
|
||||
new_sx[0] = NewValueNode(new_fg);
|
||||
return node->func_graph()->NewCNode(new_sx);
|
||||
}
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -39,7 +39,6 @@
|
|||
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
|
||||
#include "frontend/optimizer/recompute.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/context/graph_kernel_flags.h"
|
||||
#include "pipeline/jit/pipeline_split.h"
|
||||
#include "pipeline/jit/static_analysis/auto_monad.h"
|
||||
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
||||
|
@ -296,31 +295,6 @@ OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &i
|
|||
return map;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
OptPassGroupMap map({
|
||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
||||
{"cse", opt::OptPassConfig(opt::CSEPass(false))},
|
||||
});
|
||||
return map;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig elim_1 = opt::OptPassConfig({
|
||||
irpass.addn_eliminate_,
|
||||
irpass.incorporate_getitem_from_param_,
|
||||
});
|
||||
opt::OptPassConfig elim_2 = opt::OptPassConfig({
|
||||
irpass.unused_parameter_eliminate_,
|
||||
irpass.unused_output_eliminate_,
|
||||
});
|
||||
OptPassGroupMap map({
|
||||
{"elim_1", elim_1},
|
||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
||||
{"elim_2", elim_2},
|
||||
});
|
||||
return map;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}});
|
||||
}
|
||||
|
@ -375,10 +349,6 @@ void InitOpt(const ResourcePtr &res) {
|
|||
Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
|
||||
g_pass_opts["opt_trans_graph"] =
|
||||
Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true);
|
||||
g_pass_opts["opt_graph_kernel_a"] =
|
||||
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
|
||||
g_pass_opts["opt_graph_kernel_b"] =
|
||||
Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false);
|
||||
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
|
||||
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
|
||||
g_pass_opts["opt_grad_epilogue"] =
|
||||
|
@ -386,10 +356,6 @@ void InitOpt(const ResourcePtr &res) {
|
|||
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
|
||||
g_pass_opts["opt_after_recompute"] =
|
||||
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
|
||||
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
|
||||
g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
|
||||
g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
@ -424,8 +390,6 @@ bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a");
|
|||
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
|
||||
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
|
||||
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
|
||||
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
|
||||
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
|
||||
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
|
||||
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
|
||||
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
|
||||
|
@ -559,8 +523,6 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
|
|||
{"opt_after_cconv", OptPassAfterCconvGroup},
|
||||
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
|
||||
{"tuple_transform", OptPassTransformGraphGroup},
|
||||
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
||||
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
|
||||
{"add_cache_embedding", AddCacheEmbeddingPass},
|
||||
{"add_recomputation", AddRecomputationPass},
|
||||
{"cse_after_recomputation", OptAfterRecomputeGroup}};
|
||||
|
|
|
@ -216,9 +216,7 @@ void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const {
|
|||
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
||||
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
|
||||
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
|
||||
}
|
||||
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
|
||||
|
|
|
@ -602,12 +602,6 @@ bool GraphPartition::IsCut(const AnfNodePtr &node) {
|
|||
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
||||
}
|
||||
AnfNodePtr fn = inputs[0];
|
||||
if (IsValueNode<FuncGraph>(fn)) {
|
||||
auto fg = GetValueNode<FuncGraphPtr>(fn);
|
||||
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (!IsValueNode<Primitive>(fn)) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -510,7 +510,7 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
|
|||
MS_EXCEPTION_IF_NULL(prim_graph);
|
||||
FuncGraphSet graphs = prim_graph->manager()->func_graphs();
|
||||
for (auto g : graphs) {
|
||||
if (g != graph && g != nullptr && !(g->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
|
||||
if (g != graph && g != nullptr) {
|
||||
Compile(g);
|
||||
}
|
||||
}
|
||||
|
@ -542,7 +542,7 @@ uint32_t GraphCompiler::CompileGraphs(const FuncGraphPtr &func_graph) {
|
|||
// Compile sub graphs.
|
||||
FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
|
||||
for (auto sub_graph : sub_graphs) {
|
||||
if (sub_graph != func_graph && sub_graph != nullptr && !(sub_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
|
||||
if (sub_graph != func_graph && sub_graph != nullptr) {
|
||||
(void)CompileGraph(sub_graph);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import numpy
|
|||
|
||||
from mindspore import log as logger
|
||||
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
|
||||
from mindspore.common._decorator import deprecated
|
||||
from mindspore.context import ParallelMode
|
||||
from .. import context
|
||||
from .._c_expression import init_pipeline, Cell_, FuncGraph
|
||||
|
@ -1213,6 +1214,11 @@ class GraphKernel(Cell):
|
|||
A `GraphKernel` a composite of basic primitives and can be compiled into a fused kernel automatically when
|
||||
enable_graph_kernel in context is set to True.
|
||||
|
||||
This class is deprecated from version 1.3 and will be removed in a future version, use Cell instead.
|
||||
|
||||
GraphKernel is not supported user-defined cells anymore, the `GraphKernel` objects will be treated as
|
||||
normal `Cell` objects.
|
||||
|
||||
Args:
|
||||
auto_prefix (bool): Recursively generate namespaces. Default: True.
|
||||
flags (dict) : Set graph flags. Default: None.
|
||||
|
@ -1230,10 +1236,9 @@ class GraphKernel(Cell):
|
|||
... return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
|
||||
"""
|
||||
|
||||
@deprecated("1.3", "Cell", True)
|
||||
def __init__(self, auto_prefix=True, flags=None):
|
||||
super(GraphKernel, self).__init__(auto_prefix, flags)
|
||||
class_name = self.__class__.__name__
|
||||
self.add_flags(graph_kernel=class_name)
|
||||
|
||||
def construct(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
import numpy as np
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import _selected_ops
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
@ -87,7 +86,7 @@ class Softmax(Cell):
|
|||
|
||||
def __init__(self, axis=-1):
|
||||
super(Softmax, self).__init__()
|
||||
self.softmax = _selected_ops.Softmax(axis)
|
||||
self.softmax = P.Softmax(axis)
|
||||
|
||||
def construct(self, x):
|
||||
return self.softmax(x)
|
||||
|
@ -137,7 +136,7 @@ class LogSoftmax(Cell):
|
|||
|
||||
def __init__(self, axis=-1):
|
||||
super(LogSoftmax, self).__init__()
|
||||
self.log_softmax = _selected_ops.LogSoftmax(axis)
|
||||
self.log_softmax = P.LogSoftmax(axis)
|
||||
|
||||
def construct(self, x):
|
||||
return self.log_softmax(x)
|
||||
|
@ -368,7 +367,7 @@ class Tanh(Cell):
|
|||
|
||||
def __init__(self):
|
||||
super(Tanh, self).__init__()
|
||||
self.tanh = _selected_ops.Tanh()
|
||||
self.tanh = P.Tanh()
|
||||
|
||||
def construct(self, x):
|
||||
return self.tanh(x)
|
||||
|
@ -415,7 +414,7 @@ class GELU(Cell):
|
|||
|
||||
def __init__(self):
|
||||
super(GELU, self).__init__()
|
||||
self.gelu = _selected_ops.GeLU()
|
||||
self.gelu = P.GeLU()
|
||||
|
||||
def construct(self, x):
|
||||
return self.gelu(x)
|
||||
|
@ -458,7 +457,7 @@ class FastGelu(Cell):
|
|||
|
||||
def __init__(self):
|
||||
super(FastGelu, self).__init__()
|
||||
self.fast_gelu = _selected_ops.FastGeLU()
|
||||
self.fast_gelu = P.FastGeLU()
|
||||
|
||||
def construct(self, x):
|
||||
return self.fast_gelu(x)
|
||||
|
|
|
@ -30,7 +30,6 @@ from mindspore._checkparam import Validator as validator
|
|||
from mindspore._extends import cell_attr_register
|
||||
from mindspore.communication.management import get_group_size, get_rank
|
||||
from mindspore.communication import management
|
||||
from mindspore.ops import _selected_ops
|
||||
from mindspore.common import dtype as mstype
|
||||
from ..cell import Cell
|
||||
|
||||
|
@ -837,9 +836,9 @@ class LayerNorm(Cell):
|
|||
gamma_init, normalized_shape), name="gamma")
|
||||
self.beta = Parameter(initializer(
|
||||
beta_init, normalized_shape), name="beta")
|
||||
self.layer_norm = _selected_ops.LayerNorm(begin_norm_axis=self.begin_norm_axis,
|
||||
begin_params_axis=self.begin_params_axis,
|
||||
epsilon=self.epsilon)
|
||||
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis,
|
||||
begin_params_axis=self.begin_params_axis,
|
||||
epsilon=self.epsilon)
|
||||
|
||||
def construct(self, input_x):
|
||||
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
|
||||
|
|
|
@ -21,7 +21,6 @@ from mindspore.ops import operations as P
|
|||
from mindspore.ops import functional as F
|
||||
from mindspore import nn
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import _selected_ops
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from mindspore._checkparam import Validator as validator
|
||||
|
@ -48,7 +47,7 @@ class _Loss(Cell):
|
|||
if reduction == 'none':
|
||||
self.reduce = False
|
||||
|
||||
self.reduce_mean = _selected_ops.ReduceMean()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.mul = P.Mul()
|
||||
self.cast = P.Cast()
|
||||
|
@ -381,7 +380,7 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
|
|||
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
|
||||
self.sparse = validator.check_bool(sparse, "sparse")
|
||||
self.reduction = reduction
|
||||
self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits()
|
||||
self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0., mstype.float32)
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# ============================================================================
|
||||
"""momentum"""
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.ops import _selected_ops
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -157,7 +156,7 @@ class Momentum(Optimizer):
|
|||
self.use_nesterov = Validator.check_bool(use_nesterov)
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov)
|
||||
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
|
||||
|
||||
def construct(self, gradients):
|
||||
params = self.params
|
||||
|
|
|
@ -19,7 +19,6 @@ from functools import reduce
|
|||
import numpy as np
|
||||
import mindspore as ms
|
||||
from mindspore import nn
|
||||
from mindspore.ops import _selected_grad_ops as SG
|
||||
from .. import functional as F
|
||||
from .. import operations as P
|
||||
from ..operations import _grad_ops as G
|
||||
|
@ -590,7 +589,7 @@ def get_bprop_expm1(self):
|
|||
@bprop_getters.register(P.Minimum)
|
||||
def get_bprop_minimum(self):
|
||||
"""Grad definition for `Minimum` operation."""
|
||||
input_grad = SG.MinimumGrad()
|
||||
input_grad = G.MinimumGrad()
|
||||
|
||||
def bprop(x, y, out, dout):
|
||||
dx, dy = input_grad(x, y, dout)
|
||||
|
@ -602,7 +601,7 @@ def get_bprop_minimum(self):
|
|||
@bprop_getters.register(P.Maximum)
|
||||
def get_bprop_maximum(self):
|
||||
"""Grad definition for `Maximum` operation."""
|
||||
input_grad = SG.MaximumGrad()
|
||||
input_grad = G.MaximumGrad()
|
||||
|
||||
def bprop(x, y, out, dout):
|
||||
dx, dy = input_grad(x, y, dout)
|
||||
|
@ -1107,7 +1106,7 @@ def get_bprop_cosh(self):
|
|||
@bprop_getters.register(P.Abs)
|
||||
def get_bprop_abs(self):
|
||||
"""Grad definition for `Abs` operation."""
|
||||
abs_grad = SG.AbsGrad()
|
||||
abs_grad = G.AbsGrad()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = abs_grad(x, dout)
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
|
||||
"""Define the grad rules of neural network related operations."""
|
||||
import os
|
||||
from mindspore.ops import _selected_grad_ops as SG
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
|
@ -34,7 +33,7 @@ env_force_bprop_seq = os.getenv("ENV_FORCE_BPROP_SEQ")
|
|||
@bprop_getters.register(P.BiasAdd)
|
||||
def get_bprop_bias_add(self):
|
||||
"""Grad definition for `BiasAdd` operation."""
|
||||
bias_grad = SG.BiasAddGrad(self.data_format)
|
||||
bias_grad = G.BiasAddGrad(self.data_format)
|
||||
|
||||
def bprop(x, w, out, dout):
|
||||
return dout, bias_grad(dout)
|
||||
|
@ -341,7 +340,7 @@ def get_bprop_dropout_do_mask(self):
|
|||
def get_bprop_mish(self):
|
||||
"""Grad definition for `Mish` operation."""
|
||||
tanh = P.Tanh()
|
||||
tanh_grad = SG.TanhGrad()
|
||||
tanh_grad = G.TanhGrad()
|
||||
softplus = P.Softplus()
|
||||
softplus_grad = G.SoftplusGrad()
|
||||
|
||||
|
@ -580,7 +579,7 @@ def get_bprop_softsign(self):
|
|||
@bprop_getters.register(P.Tanh)
|
||||
def get_bprop_tanh(self):
|
||||
"""Grad definition for `Tanh` operation."""
|
||||
tanh_grad = SG.TanhGrad()
|
||||
tanh_grad = G.TanhGrad()
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = tanh_grad(out, dout)
|
||||
|
|
|
@ -1,50 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" resolved grad ops """
|
||||
from mindspore.ops.op_selector import new_ops_selector
|
||||
|
||||
op_selector = new_ops_selector(
|
||||
"mindspore.ops.operations._grad_ops", "mindspore.nn._graph_kernels")
|
||||
|
||||
|
||||
@op_selector
|
||||
class MaximumGrad:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class MinimumGrad:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class AbsGrad:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class BiasAddGrad:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class TanhGrad:
|
||||
def __call__(self, *args):
|
||||
pass
|
|
@ -1,120 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
""" resolve ops """
|
||||
from mindspore.ops.op_selector import new_ops_selector
|
||||
|
||||
op_selector = new_ops_selector(
|
||||
"mindspore.ops.operations", "mindspore.nn._graph_kernels")
|
||||
opt_selector = new_ops_selector(
|
||||
"mindspore.nn.optim", "mindspore.nn._graph_kernels")
|
||||
nn_selector = new_ops_selector(
|
||||
"mindspore.nn", "mindspore.nn._graph_kernels")
|
||||
|
||||
|
||||
@nn_selector
|
||||
class BatchNorm2d:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class ReLU:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class ReduceMean:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class BiasAdd:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class ApplyMomentum:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class SoftmaxCrossEntropyWithLogits:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class LogSoftmax:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class Tanh:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class GeLU:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class FastGeLU:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class LayerNorm:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class Softmax:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class LambUpdateWithLR:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class LambNextMV:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class LambApplyOptimizerAssign:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@op_selector
|
||||
class LambApplyWeightAssign:
|
||||
def __call__(self, *args):
|
||||
pass
|
|
@ -20,6 +20,7 @@ which can be used to control the switch of op type: GraphKernel or Primitive.
|
|||
import importlib
|
||||
import inspect
|
||||
from mindspore import context
|
||||
from mindspore.common._decorator import deprecated
|
||||
|
||||
|
||||
class _OpSelector:
|
||||
|
@ -70,6 +71,7 @@ class _OpSelector:
|
|||
return op(*args, **kwargs)
|
||||
|
||||
|
||||
@deprecated("1.3", "basic Primitive", False)
|
||||
def new_ops_selector(primitive_pkg, graph_kernel_pkg):
|
||||
"""
|
||||
A factory method to return an op selector
|
||||
|
@ -83,6 +85,8 @@ def new_ops_selector(primitive_pkg, graph_kernel_pkg):
|
|||
The order of the highest priority to lowest priority is (1), (2), (3)
|
||||
If the GraphKernel switch is off, then op_type will always be PRIMITIVE.
|
||||
|
||||
The user-defined GraphKernel Cell is deprecated, this interface will be removed in a future version.
|
||||
|
||||
Args:
|
||||
primitive_pkg (str): primitive op's package name
|
||||
graph_kernel_pkg (str): graph kernel op's package name
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""Other operators."""
|
||||
import functools
|
||||
from mindspore.common import monad
|
||||
from mindspore.common._decorator import deprecated
|
||||
from .. import signature as sig
|
||||
from ..._checkparam import Validator as validator, Rel
|
||||
from ...common import dtype as mstype
|
||||
|
@ -84,6 +85,8 @@ class InplaceAssign(PrimitiveWithInfer):
|
|||
Inplace assign `Parameter` with a value.
|
||||
This primitive can only use in graph kernel.
|
||||
|
||||
InplaceAssign is deprecated from version 1.3 and will be removed in a future version, use Assign instead.
|
||||
|
||||
Inputs:
|
||||
- **variable** (Parameter) - The `Parameter`.
|
||||
- **value** (Tensor) - The value to be assigned.
|
||||
|
@ -110,7 +113,8 @@ class InplaceAssign(PrimitiveWithInfer):
|
|||
>>> net = Net()
|
||||
>>> output = net(x)
|
||||
>>> print(output)
|
||||
"""
|
||||
"""
|
||||
@deprecated("1.3", "Assign", False)
|
||||
@ prim_attr_register
|
||||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])
|
||||
|
|
|
@ -26,7 +26,6 @@ from mindspore.nn.loss.loss import _Loss
|
|||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import _selected_ops
|
||||
from mindspore.parallel._utils import _reset_op_id
|
||||
from mindspore.train import Model
|
||||
from mindspore.context import ParallelMode
|
||||
|
@ -76,7 +75,7 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
|
|||
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
|
||||
self.sparse = sparse
|
||||
self.reduction = reduction
|
||||
self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits()
|
||||
self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0., mstype.float32)
|
||||
|
|
Loading…
Reference in New Issue