forked from mindspore-Ecosystem/mindspore
Clean up GraphKernel code in the frontend
1. Mark class GraphKernel as deprecated and treat it as a Cell. 2. Mark class InplaceAssign as deprecated; Assign is the suggested replacement. 3. Mark op_selector as deprecated; remove _selected_ops and _selected_grad_ops and replace them with the real operations. 4. Remove the two GraphKernel passes from the frontend. 5. Remove the GraphKernel code from other modules.
This commit is contained in:
parent
c3f3fcab3f
commit
771e3f61f3
|
@ -17,6 +17,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/optimizer/common/optimizer.h"
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
|
#include "backend/optimizer/ascend/ir_fission/dynamic_rnn_grad_fission_v2.h"
|
||||||
|
@ -110,8 +111,6 @@
|
||||||
#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h"
|
#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h"
|
||||||
#include "backend/optimizer/ascend/ir_fission/split_fission.h"
|
#include "backend/optimizer/ascend/ir_fission/split_fission.h"
|
||||||
#include "backend/optimizer/ascend/ir_fission/splitv_fission.h"
|
#include "backend/optimizer/ascend/ir_fission/splitv_fission.h"
|
||||||
#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
|
|
||||||
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
|
|
||||||
#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
|
#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
|
||||||
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
|
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
|
||||||
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
|
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
|
||||||
|
@ -199,19 +198,6 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
|
||||||
ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
|
ir_fusion_pm->AddPass(std::make_shared<BCEWithLogitsLossFission>());
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
||||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
|
||||||
MS_EXCEPTION_IF_NULL(optimizer);
|
|
||||||
auto common_process = std::make_shared<PassManager>("graph_kernel_common_process");
|
|
||||||
MS_EXCEPTION_IF_NULL(common_process);
|
|
||||||
common_process->AddPass(std::make_shared<ModifyOpAttrs>());
|
|
||||||
common_process->AddPass(std::make_shared<RemoveNoUseReshapeOp>());
|
|
||||||
optimizer->AddPassManager(common_process);
|
|
||||||
(void)optimizer->Optimize(kernel_graph);
|
|
||||||
kernel_graph->SetExecOrderByDefault();
|
|
||||||
}
|
|
||||||
|
|
||||||
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
auto optimizer = std::make_shared<GraphOptimizer>();
|
auto optimizer = std::make_shared<GraphOptimizer>();
|
||||||
|
|
|
@ -24,7 +24,6 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
||||||
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||||
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||||
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||||
void AscendGraphKernelCommonProcess(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
|
||||||
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||||
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -344,7 +344,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
|
||||||
|
|
||||||
auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
|
auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
|
||||||
auto real_input_node = kernel_with_index.first;
|
auto real_input_node = kernel_with_index.first;
|
||||||
if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
if (kernel::IsWeightBoundary(real_input_node)) {
|
||||||
// weight
|
// weight
|
||||||
origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
|
origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
|
||||||
if (origin_type == kTypeUnknown) {
|
if (origin_type == kTypeUnknown) {
|
||||||
|
@ -358,9 +358,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
|
||||||
const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
|
const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
|
||||||
// In graph kernel, we check parameter,
|
// In graph kernel, we check parameter,
|
||||||
// the eliminate pass will not eliminate this case, so we just do not insert the no used cast.
|
// the eliminate pass will not eliminate this case, so we just do not insert the no used cast.
|
||||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
|
if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
|
||||||
new_inputs.push_back(cur_input);
|
|
||||||
} else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
|
|
||||||
auto cast =
|
auto cast =
|
||||||
AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
|
AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
|
||||||
MS_EXCEPTION_IF_NULL(cast);
|
MS_EXCEPTION_IF_NULL(cast);
|
||||||
|
|
|
@ -78,21 +78,12 @@ const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePt
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> todos = {node};
|
CNodePtr cnode = node->cast<CNodePtr>();
|
||||||
if (AnfAlgo::IsGraphKernel(node)) {
|
size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
|
||||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
for (size_t i = 0; i < in_num; ++i) {
|
||||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
|
||||||
kernel::GetValidKernelNodes(sub_graph, &todos);
|
MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
|
||||||
}
|
<< cnode->DebugString() << "]";
|
||||||
|
|
||||||
for (auto &t : todos) {
|
|
||||||
CNodePtr cnode = t->cast<CNodePtr>();
|
|
||||||
size_t in_num = AnfAlgo::GetInputTensorNum(cnode);
|
|
||||||
for (size_t i = 0; i < in_num; ++i) {
|
|
||||||
if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
|
|
||||||
MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
|
|
||||||
<< cnode->DebugString() << "]";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -30,8 +30,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||||
const std::vector<bool> &need_insert_cast) {
|
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||||
|
@ -42,32 +41,23 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
|
||||||
for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
|
for (size_t output_idx = 0; output_idx < out_num; ++output_idx) {
|
||||||
AnfNodePtr replace_node = nullptr;
|
AnfNodePtr replace_node = nullptr;
|
||||||
const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
|
const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
|
||||||
const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
|
const auto origin_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
|
||||||
auto idx = NewValueNode(SizeToLong(output_idx));
|
auto idx = NewValueNode(SizeToLong(output_idx));
|
||||||
MS_EXCEPTION_IF_NULL(idx);
|
MS_EXCEPTION_IF_NULL(idx);
|
||||||
auto imm = std::make_shared<Int64Imm>(output_idx);
|
auto imm = std::make_shared<Int64Imm>(output_idx);
|
||||||
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
|
idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
|
||||||
auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
|
auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
|
||||||
AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get());
|
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, getitem.get());
|
||||||
if (need_insert_cast[output_idx]) {
|
const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
|
||||||
const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
|
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
|
||||||
TypeId origin_type(kTypeUnknown);
|
if (origin_type != device_type) {
|
||||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
replace_node = AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape,
|
||||||
origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
|
origin_type, AnfAlgo::GetOutputReshapeType(getitem, 0));
|
||||||
}
|
MS_EXCEPTION_IF_NULL(replace_node);
|
||||||
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
|
replace_node->set_scope(cnode->scope());
|
||||||
const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
|
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
||||||
if (origin_type != device_type) {
|
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
|
||||||
replace_node = AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape,
|
kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
|
||||||
infer_type, AnfAlgo::GetOutputReshapeType(getitem, 0));
|
|
||||||
MS_EXCEPTION_IF_NULL(replace_node);
|
|
||||||
replace_node->set_scope(cnode->scope());
|
|
||||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
|
||||||
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode, output_idx)) {
|
|
||||||
kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
replace_node = getitem;
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
replace_node = getitem;
|
replace_node = getitem;
|
||||||
|
@ -81,8 +71,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
|
||||||
return make_tuple;
|
return make_tuple;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||||
const std::vector<bool> &need_insert_cast) {
|
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
|
if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
|
||||||
|
@ -92,23 +81,14 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
|
||||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||||
// Single output
|
// Single output
|
||||||
if (!cnode->Type()->isa<Tuple>()) {
|
if (!cnode->Type()->isa<Tuple>()) {
|
||||||
if (!need_insert_cast[0]) {
|
|
||||||
return cnode;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
|
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
|
||||||
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
|
||||||
const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
|
const TypeId origin_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
|
||||||
TypeId origin_type(kTypeUnknown);
|
|
||||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
|
||||||
origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
|
|
||||||
}
|
|
||||||
origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
|
|
||||||
const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
|
const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
|
||||||
AnfNodePtr replace_node = cnode;
|
AnfNodePtr replace_node = cnode;
|
||||||
if (origin_type != device_type) {
|
if (origin_type != device_type) {
|
||||||
replace_node = AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape,
|
replace_node = AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape,
|
||||||
infer_type, AnfAlgo::GetOutputReshapeType(cnode, 0));
|
origin_type, AnfAlgo::GetOutputReshapeType(cnode, 0));
|
||||||
MS_EXCEPTION_IF_NULL(replace_node);
|
MS_EXCEPTION_IF_NULL(replace_node);
|
||||||
replace_node->set_scope(cnode->scope());
|
replace_node->set_scope(cnode->scope());
|
||||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
|
||||||
|
@ -119,69 +99,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
|
||||||
return replace_node;
|
return replace_node;
|
||||||
}
|
}
|
||||||
// Multiple output
|
// Multiple output
|
||||||
return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast);
|
return InsertCastForMultipleOutput(func_graph, cnode);
|
||||||
}
|
|
||||||
|
|
||||||
AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
|
||||||
// insert cast for ops in graph kernel.
|
|
||||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
|
||||||
auto mng = sub_graph->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
|
||||||
std::vector<AnfNodePtr> todo;
|
|
||||||
kernel::GetValidKernelNodes(sub_graph, &todo);
|
|
||||||
auto outputs = AnfAlgo::GetAllOutput(sub_graph->output(), {prim::kPrimTupleGetItem});
|
|
||||||
std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
|
|
||||||
for (auto &output : outputs) {
|
|
||||||
size_t index = 0;
|
|
||||||
if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
|
|
||||||
ValuePtr tuple_index_value = GetValueNode(output->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
|
|
||||||
MS_EXCEPTION_IF_NULL(tuple_index_value);
|
|
||||||
if (!tuple_index_value->isa<Int64Imm>()) {
|
|
||||||
MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64";
|
|
||||||
}
|
|
||||||
index = tuple_index_value->cast<Int64ImmPtr>()->value();
|
|
||||||
}
|
|
||||||
graph_rets.emplace_back(std::pair<AnfNodePtr, size_t>(output, index));
|
|
||||||
}
|
|
||||||
for (auto &t : todo) {
|
|
||||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t);
|
|
||||||
// process input
|
|
||||||
CNodePtr t_cnode = t->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(t_cnode);
|
|
||||||
auto t_new_node = InsertCastForInput(sub_graph, t_cnode);
|
|
||||||
AnfNodePtr t_new_node_1 = nullptr;
|
|
||||||
std::vector<bool> need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true);
|
|
||||||
// process output
|
|
||||||
auto iter = std::find_if(graph_rets.begin(), graph_rets.end(),
|
|
||||||
[&t](const std::pair<AnfNodePtr, size_t> &ret) { return ret.first == t; });
|
|
||||||
if (iter != graph_rets.end()) {
|
|
||||||
auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t);
|
|
||||||
auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second);
|
|
||||||
auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin());
|
|
||||||
if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) {
|
|
||||||
need_insert_cast[iter->second] = false;
|
|
||||||
} else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) {
|
|
||||||
need_insert_cast[iter->second] = false;
|
|
||||||
}
|
|
||||||
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
|
|
||||||
} else {
|
|
||||||
t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (t_new_node_1 != nullptr && t_new_node_1 != t) {
|
|
||||||
(void)mng->Replace(t, t_new_node_1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert cast for graph kernel.
|
|
||||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
|
||||||
// process input
|
|
||||||
CNodePtr cnode = node->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
auto new_node = InsertCastForInput(func_graph, cnode);
|
|
||||||
// process output
|
|
||||||
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
|
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -196,11 +114,6 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
|
||||||
if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
|
if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (AnfAlgo::IsGraphKernel(node)) {
|
|
||||||
return ProcessGraphKernelOp(func_graph, node);
|
|
||||||
}
|
|
||||||
// insert cast for single op.
|
|
||||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||||
// process input
|
// process input
|
||||||
CNodePtr cnode = node->cast<CNodePtr>();
|
CNodePtr cnode = node->cast<CNodePtr>();
|
||||||
|
@ -211,7 +124,7 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
|
||||||
kernel_graph->ReplaceInternalOutput(node, new_node);
|
kernel_graph->ReplaceInternalOutput(node, new_node);
|
||||||
}
|
}
|
||||||
// process output
|
// process output
|
||||||
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
|
return InsertCastForOutput(func_graph, new_node);
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -150,9 +150,6 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto next_cnode = next_node->cast<CNodePtr>();
|
auto next_cnode = next_node->cast<CNodePtr>();
|
||||||
if (AnfAlgo::IsGraphKernel(next_node)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto next_op_name = AnfAlgo::GetCNodeName(next_cnode);
|
auto next_op_name = AnfAlgo::GetCNodeName(next_cnode);
|
||||||
if (next_op_name == prim::kPrimSend->name() || next_op_name == kStackPushOpName) {
|
if (next_op_name == prim::kPrimSend->name() || next_op_name == kStackPushOpName) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -224,9 +221,6 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(prior_op);
|
MS_EXCEPTION_IF_NULL(prior_op);
|
||||||
if (AnfAlgo::IsGraphKernel(prior_op)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
|
||||||
if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name() ||
|
if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name() ||
|
||||||
|
|
|
@ -1,99 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "utils/utils.h"
|
|
||||||
#include "backend/optimizer/common/helper.h"
|
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
|
||||||
#include "base/core_ops.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
namespace {
|
|
||||||
AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) {
|
|
||||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
|
|
||||||
auto input_format = AnfAlgo::GetInputFormat(cnode, 0);
|
|
||||||
if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode);
|
|
||||||
return cnode;
|
|
||||||
}
|
|
||||||
|
|
||||||
AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) {
|
|
||||||
auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
|
|
||||||
if (input_shape.size() != 5) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto multiples = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, kAttrMultiples);
|
|
||||||
if (multiples.size() == 4 && multiples[1] == 1) {
|
|
||||||
multiples.push_back(1);
|
|
||||||
AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode);
|
|
||||||
}
|
|
||||||
|
|
||||||
return cnode;
|
|
||||||
}
|
|
||||||
|
|
||||||
AnfNodePtr ModifyAttrs(const CNodePtr &cnode) {
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
|
||||||
if (op_name == prim::kPrimTile->name()) {
|
|
||||||
return ModifyTileOpAttrs(cnode);
|
|
||||||
} else if (op_name == prim::kPrimReduceSum->name()) {
|
|
||||||
// kPrimReduceMean
|
|
||||||
// kPrimReduceSum
|
|
||||||
// kPrimReduceAll
|
|
||||||
// kPrimReduceMax
|
|
||||||
// kPrimReduceMin
|
|
||||||
return ModifyReduceOpsAttrs(cnode);
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
||||||
const EquivPtr &) const {
|
|
||||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node);
|
|
||||||
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
|
||||||
auto manager = fg->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
|
||||||
std::vector<AnfNodePtr> todos;
|
|
||||||
kernel::GetValidKernelNodes(fg, &todos);
|
|
||||||
for (auto &t : todos) {
|
|
||||||
auto new_node = ModifyAttrs(t->cast<CNodePtr>());
|
|
||||||
if (new_node != nullptr && new_node != t) {
|
|
||||||
(void)manager->Replace(t, new_node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,33 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
|
|
||||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
|
|
||||||
|
|
||||||
#include "backend/optimizer/common/optimizer.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
class ModifyOpAttrs : public PatternProcessPass {
|
|
||||||
public:
|
|
||||||
explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {}
|
|
||||||
~ModifyOpAttrs() override = default;
|
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
|
||||||
};
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
|
|
|
@ -1,66 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "backend/optimizer/common/helper.h"
|
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
|
||||||
#include "base/core_ops.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
namespace {
|
|
||||||
AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) {
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
|
||||||
if (op_name != prim::kPrimReshape->name()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
|
|
||||||
auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
|
|
||||||
if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
return cnode->input(1);
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
|
||||||
const EquivPtr &) const {
|
|
||||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "====process op: " << AnfAlgo::GetCNodeName(node);
|
|
||||||
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
|
||||||
auto manager = fg->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(manager);
|
|
||||||
std::vector<AnfNodePtr> todos;
|
|
||||||
kernel::GetValidKernelNodes(fg, &todos);
|
|
||||||
for (auto &t : todos) {
|
|
||||||
auto new_node = RemoveReshapeOp(t->cast<CNodePtr>());
|
|
||||||
if (new_node != nullptr && new_node != t) {
|
|
||||||
(void)manager->Replace(t, new_node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,33 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
|
|
||||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
|
|
||||||
|
|
||||||
#include "backend/optimizer/common/optimizer.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace opt {
|
|
||||||
class RemoveNoUseReshapeOp : public PatternProcessPass {
|
|
||||||
public:
|
|
||||||
explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {}
|
|
||||||
~RemoveNoUseReshapeOp() override = default;
|
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
|
||||||
};
|
|
||||||
} // namespace opt
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
|
|
|
@ -122,9 +122,6 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f
|
||||||
if (node == nullptr || !node->isa<CNode>()) {
|
if (node == nullptr || !node->isa<CNode>()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (AnfAlgo::IsGraphKernel(node)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
std::vector<CNodePtr> cast_nodes;
|
std::vector<CNodePtr> cast_nodes;
|
||||||
|
|
|
@ -596,20 +596,21 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
|
||||||
std::vector<PrimitivePtr> GetFusibleOpList() {
|
std::vector<PrimitivePtr> GetFusibleOpList() {
|
||||||
#if ENABLE_D
|
#if ENABLE_D
|
||||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
||||||
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||||
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::KPrimTransData};
|
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimInplaceAssign,
|
||||||
|
prim::KPrimTransData};
|
||||||
#elif ENABLE_GPU
|
#elif ENABLE_GPU
|
||||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
||||||
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
|
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
|
||||||
prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||||
prim::kPrimAssign, prim::kPrimLessEqual, prim::kPrimGreaterEqual, prim::kPrimReduceMax, prim::kPrimReduceMin,
|
prim::kPrimAssign, prim::kPrimLessEqual, prim::kPrimGreaterEqual, prim::kPrimReduceMax, prim::kPrimReduceMin,
|
||||||
prim::kPrimLess};
|
prim::kPrimLess, prim::kPrimInplaceAssign};
|
||||||
#else
|
#else
|
||||||
std::vector<PrimitivePtr> fusible_basic_ops;
|
std::vector<PrimitivePtr> fusible_basic_ops;
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -81,13 +81,9 @@ const AnfNodePtr AddTrainingAttr::Process(const FuncGraphPtr &func_graph, const
|
||||||
if (iter == MarkOp.end()) {
|
if (iter == MarkOp.end()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (AnfAlgo::IsGraphKernel(node)) {
|
auto cnode = node->cast<CNodePtr>();
|
||||||
return nullptr;
|
AddAttrTraining(func_graph, cnode);
|
||||||
} else {
|
return cnode;
|
||||||
auto cnode = node->cast<CNodePtr>();
|
|
||||||
AddAttrTraining(func_graph, cnode);
|
|
||||||
return cnode;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
*/
|
*/
|
||||||
#include "backend/optimizer/pass/convert_attr_to_unify_mindir.h"
|
#include "backend/optimizer/pass/convert_attr_to_unify_mindir.h"
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
@ -29,32 +28,22 @@ const AnfNodePtr ConvertAttrToUnifyMindIR::Process(const FuncGraphPtr &, const A
|
||||||
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
|
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
std::vector<AnfNodePtr> todos;
|
|
||||||
if (AnfAlgo::IsGraphKernel(node)) {
|
|
||||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
|
||||||
kernel::GetValidKernelNodes(sub_graph, &todos);
|
|
||||||
} else {
|
|
||||||
todos.push_back(node);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto &t : todos) {
|
CNodePtr cnode = node->cast<CNodePtr>();
|
||||||
CNodePtr cnode = t->cast<CNodePtr>();
|
auto inputs = cnode->inputs();
|
||||||
auto inputs = cnode->inputs();
|
AnfNodePtr op = inputs[0];
|
||||||
AnfNodePtr op = inputs[0];
|
if (IsValueNode<Primitive>(op)) {
|
||||||
if (IsValueNode<Primitive>(op)) {
|
auto prim = GetValueNode<PrimitivePtr>(op);
|
||||||
auto prim = GetValueNode<PrimitivePtr>(op);
|
auto attrs = prim->attrs();
|
||||||
auto attrs = prim->attrs();
|
std::string type_name = prim->name();
|
||||||
std::string type_name = prim->name();
|
for (auto attr : attrs) {
|
||||||
for (auto attr : attrs) {
|
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr.second);
|
||||||
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr.second);
|
if (converted) {
|
||||||
if (converted) {
|
prim->set_attr(attr.first, attr.second);
|
||||||
prim->set_attr(attr.first, attr.second);
|
}
|
||||||
}
|
bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(type_name, attr.first, &attr.second);
|
||||||
bool converted_ir_attr = CheckAndConvertUtils::CheckIrAttrtoOpAttr(type_name, attr.first, &attr.second);
|
if (converted_ir_attr) {
|
||||||
if (converted_ir_attr) {
|
prim->set_attr(attr.first, attr.second);
|
||||||
prim->set_attr(attr.first, attr.second);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
*/
|
*/
|
||||||
#include "backend/optimizer/pass/convert_const_input_to_attr.h"
|
#include "backend/optimizer/pass/convert_const_input_to_attr.h"
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
@ -34,40 +33,31 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
|
||||||
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
|
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
std::vector<AnfNodePtr> todos;
|
|
||||||
if (AnfAlgo::IsGraphKernel(node)) {
|
|
||||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
|
||||||
kernel::GetValidKernelNodes(sub_graph, &todos);
|
|
||||||
} else {
|
|
||||||
todos.push_back(node);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto &t : todos) {
|
CNodePtr cnode = node->cast<CNodePtr>();
|
||||||
CNodePtr cnode = t->cast<CNodePtr>();
|
ConstInputToAttrInfoRegister reg;
|
||||||
ConstInputToAttrInfoRegister reg;
|
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), &reg)) {
|
||||||
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), &reg)) {
|
return nullptr;
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
|
|
||||||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
|
|
||||||
if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
|
|
||||||
auto ms_context = MsContext::GetInstance();
|
|
||||||
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (AnfAlgo::IsDynamicShape(cnode) &&
|
|
||||||
DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
|
|
||||||
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
|
|
||||||
}
|
}
|
||||||
|
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
|
||||||
|
AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
|
||||||
|
if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimGatherD->name()) {
|
||||||
|
auto ms_context = MsContext::GetInstance();
|
||||||
|
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (AnfAlgo::IsDynamicShape(cnode) &&
|
||||||
|
DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
|
||||||
|
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
|
||||||
|
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -100,24 +100,6 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
|
|
||||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
|
||||||
auto mng = sub_graph->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
|
||||||
std::vector<AnfNodePtr> todo;
|
|
||||||
kernel::GetValidKernelNodes(sub_graph, &todo);
|
|
||||||
|
|
||||||
for (auto &t : todo) {
|
|
||||||
auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast<CNodePtr>());
|
|
||||||
if (t_new_node != nullptr && t_new_node != t) {
|
|
||||||
(void)mng->Replace(t, t_new_node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
|
@ -129,11 +111,8 @@ const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &fun
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (AnfAlgo::IsGraphKernel(node)) {
|
|
||||||
return ProcessGraphKernelOp(node);
|
return ConstInputToTensorInput(func_graph, node->cast<CNodePtr>());
|
||||||
} else {
|
|
||||||
return ConstInputToTensorInput(func_graph, node->cast<CNodePtr>());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -95,15 +95,6 @@ const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &fu
|
||||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
|
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (AnfAlgo::IsGraphKernel(node)) {
|
|
||||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
|
||||||
std::vector<AnfNodePtr> todos;
|
|
||||||
kernel::GetValidKernelNodes(sub_graph, &todos);
|
|
||||||
for (auto &t : todos) {
|
|
||||||
ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast<CNodePtr>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
|
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
|
@ -170,26 +170,6 @@ const AnfNodePtr EliminateRedundantOp::Process(const FuncGraphPtr &func_graph, c
|
||||||
if (cnode == nullptr || func_graph == nullptr) {
|
if (cnode == nullptr || func_graph == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (AnfAlgo::IsGraphKernel(node)) {
|
|
||||||
// do eliminate for ops in graph kernel.
|
|
||||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
|
||||||
auto mng = sub_graph->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
|
||||||
std::vector<AnfNodePtr> todo;
|
|
||||||
kernel::GetValidKernelNodes(sub_graph, &todo);
|
|
||||||
for (auto &t : todo) {
|
|
||||||
CNodePtr t_cnode = t->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(t_cnode);
|
|
||||||
auto t_new_node = DoEliminate(sub_graph, t_cnode);
|
|
||||||
if (t_new_node != nullptr && t_new_node != t) {
|
|
||||||
(void)mng->Replace(t, t_new_node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
// do eliminate for single op.
|
|
||||||
return DoEliminate(func_graph, cnode);
|
return DoEliminate(func_graph, cnode);
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -30,20 +30,7 @@ const BaseRef EraseVisitAttr::DefinePattern() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
const AnfNodePtr EraseVisitAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
|
const AnfNodePtr EraseVisitAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
|
||||||
if (node != nullptr && AnfAlgo::IsRealCNodeKernel(node)) {
|
AnfAlgo::EraseNodeAttr(kAttrVisited, node);
|
||||||
if (AnfAlgo::IsGraphKernel(node)) {
|
|
||||||
auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
|
||||||
std::vector<AnfNodePtr> todos;
|
|
||||||
kernel::GetValidKernelNodes(fg, &todos);
|
|
||||||
for (auto &t : todos) {
|
|
||||||
AnfAlgo::EraseNodeAttr(kAttrVisited, t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
AnfAlgo::EraseNodeAttr(kAttrVisited, node);
|
|
||||||
} else {
|
|
||||||
AnfAlgo::EraseNodeAttr(kAttrVisited, node);
|
|
||||||
}
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
|
|
|
@ -949,7 +949,6 @@ void AscendSession::InitRuntimeResource() {
|
||||||
void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||||
MS_LOG(INFO) << "HardwareOptimize start!";
|
MS_LOG(INFO) << "HardwareOptimize start!";
|
||||||
opt::AscendBackendOptimization(kernel_graph);
|
opt::AscendBackendOptimization(kernel_graph);
|
||||||
opt::AscendGraphKernelCommonProcess(kernel_graph);
|
|
||||||
GraphKernelOptimize(kernel_graph);
|
GraphKernelOptimize(kernel_graph);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
kernel_graph->SetExecOrderByDefault();
|
kernel_graph->SetExecOrderByDefault();
|
||||||
|
|
|
@ -433,7 +433,9 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
||||||
auto cnode = FuncGraph::NewCNode(inputs);
|
auto cnode = FuncGraph::NewCNode(inputs);
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
|
cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
|
||||||
CreateKernelInfoFromNewParameter(cnode);
|
if (AnfAlgo::IsGraphKernel(cnode)) {
|
||||||
|
CreateKernelInfoFromNewParameter(cnode);
|
||||||
|
}
|
||||||
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
|
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
|
||||||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
|
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
|
||||||
}
|
}
|
||||||
|
@ -443,9 +445,6 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
|
void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
|
||||||
if (!AnfAlgo::IsGraphKernel(cnode)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
|
||||||
|
|
|
@ -45,10 +45,6 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
|
||||||
TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
|
TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
|
||||||
k_graph_ = std::make_shared<FuncGraph>();
|
k_graph_ = std::make_shared<FuncGraph>();
|
||||||
}
|
}
|
||||||
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
|
||||||
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
|
||||||
k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
|
|
||||||
}
|
|
||||||
// To keep switch_layer's inputs from being inlined
|
// To keep switch_layer's inputs from being inlined
|
||||||
k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
|
k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
|
||||||
k_graph_->set_stage(primal_graph->stage());
|
k_graph_->set_stage(primal_graph->stage());
|
||||||
|
@ -58,11 +54,6 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
|
||||||
tape_ = std::make_shared<FuncGraph>();
|
tape_ = std::make_shared<FuncGraph>();
|
||||||
}
|
}
|
||||||
tape_->set_stage(primal_graph->stage());
|
tape_->set_stage(primal_graph->stage());
|
||||||
// Add "_Grad" postfix
|
|
||||||
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
|
||||||
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad";
|
|
||||||
tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
|
|
||||||
}
|
|
||||||
|
|
||||||
dout_ = tape_->add_parameter();
|
dout_ = tape_->add_parameter();
|
||||||
}
|
}
|
||||||
|
|
|
@ -194,8 +194,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||||
// Incorporation
|
// Incorporation
|
||||||
incorporate_getitem_set_ =
|
incorporate_getitem_set_ =
|
||||||
MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
|
MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
|
||||||
incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(),
|
|
||||||
"incorporate_getitem_from_param", IsCNodeGraphKernel);
|
|
||||||
incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
|
incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
|
||||||
incorporate_call_switch_ =
|
incorporate_call_switch_ =
|
||||||
MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
|
MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
|
||||||
|
@ -211,19 +209,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||||
print_tuple_wrapper_ =
|
print_tuple_wrapper_ =
|
||||||
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
|
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
|
||||||
|
|
||||||
// Unused parameter eliminate
|
|
||||||
unused_parameter_eliminate_ =
|
|
||||||
MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel);
|
|
||||||
unused_output_eliminate_ =
|
|
||||||
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
|
|
||||||
|
|
||||||
// tuple parameter graph transform
|
// tuple parameter graph transform
|
||||||
call_graph_tuple_transform_ =
|
call_graph_tuple_transform_ =
|
||||||
MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transorm", IsCNode);
|
MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transorm", IsCNode);
|
||||||
|
|
||||||
// AddN eliminate
|
|
||||||
addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
|
|
||||||
|
|
||||||
// RowTensor Eliminate
|
// RowTensor Eliminate
|
||||||
row_tensor_eliminate_ = MakeSubstitution(
|
row_tensor_eliminate_ = MakeSubstitution(
|
||||||
std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",
|
std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",
|
||||||
|
|
|
@ -112,7 +112,6 @@ class OptimizeIRPassLib {
|
||||||
|
|
||||||
// Incorporation
|
// Incorporation
|
||||||
SubstitutionPtr incorporate_getitem_set_;
|
SubstitutionPtr incorporate_getitem_set_;
|
||||||
SubstitutionPtr incorporate_getitem_from_param_;
|
|
||||||
SubstitutionPtr incorporate_call_;
|
SubstitutionPtr incorporate_call_;
|
||||||
SubstitutionPtr incorporate_call_switch_;
|
SubstitutionPtr incorporate_call_switch_;
|
||||||
|
|
||||||
|
@ -125,16 +124,9 @@ class OptimizeIRPassLib {
|
||||||
// Convert
|
// Convert
|
||||||
SubstitutionPtr print_tuple_wrapper_;
|
SubstitutionPtr print_tuple_wrapper_;
|
||||||
|
|
||||||
// Unused parameter eliminate
|
|
||||||
SubstitutionPtr unused_parameter_eliminate_;
|
|
||||||
SubstitutionPtr unused_output_eliminate_;
|
|
||||||
|
|
||||||
// tuple parameter graph transform
|
// tuple parameter graph transform
|
||||||
SubstitutionPtr call_graph_tuple_transform_;
|
SubstitutionPtr call_graph_tuple_transform_;
|
||||||
|
|
||||||
// AddN eliminate
|
|
||||||
SubstitutionPtr addn_eliminate_;
|
|
||||||
|
|
||||||
// RowTensor Eliminate
|
// RowTensor Eliminate
|
||||||
SubstitutionPtr row_tensor_eliminate_;
|
SubstitutionPtr row_tensor_eliminate_;
|
||||||
|
|
||||||
|
@ -214,23 +206,6 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) {
|
||||||
return IsValueNode<FuncGraph>(inp0);
|
return IsValueNode<FuncGraph>(inp0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if CNode Input 0 is Func Graph of graph kernel.
|
|
||||||
inline bool IsCNodeGraphKernel(const AnfNodePtr &node) {
|
|
||||||
if (node == nullptr || !node->isa<CNode>()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto inp0 = node->cast<CNodePtr>()->input(0);
|
|
||||||
if (IsValueNode<FuncGraph>(inp0)) {
|
|
||||||
auto fg = GetValueNode<FuncGraphPtr>(inp0);
|
|
||||||
if (fg == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if CNode Input 0 is CNode
|
// Check if CNode Input 0 is CNode
|
||||||
inline bool IsCNodeDup(const AnfNodePtr &node) {
|
inline bool IsCNodeDup(const AnfNodePtr &node) {
|
||||||
if (node == nullptr || !node->isa<CNode>()) {
|
if (node == nullptr || !node->isa<CNode>()) {
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
@ -135,25 +134,6 @@ class IncorporateGetitem : public AnfVisitor {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
|
||||||
// If graph kernel has muti output, do not split.
|
|
||||||
// some graph kernel output has EnvInstance node or DeadCode node should split.
|
|
||||||
auto output = fg_->output();
|
|
||||||
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
|
|
||||||
auto output_cnode = output->cast<CNodePtr>();
|
|
||||||
auto outputs = output_cnode->inputs();
|
|
||||||
int64_t real_output_cnt = 0;
|
|
||||||
for (size_t i = 1; i < outputs.size(); ++i) {
|
|
||||||
if (IsCNode(outputs[i]) || IsValueNode<tensor::Tensor>(outputs[i]) || IsParam(outputs[i])) {
|
|
||||||
real_output_cnt++;
|
|
||||||
if (real_output_cnt > 1) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto new_fg = getitem_transform_(fg_, idx_);
|
auto new_fg = getitem_transform_(fg_, idx_);
|
||||||
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
|
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
|
||||||
return node->func_graph()->NewCNode(args_);
|
return node->func_graph()->NewCNode(args_);
|
||||||
|
@ -184,171 +164,6 @@ class IncorporateGetitem : public AnfVisitor {
|
||||||
internal::GetitemTransform getitem_transform_;
|
internal::GetitemTransform getitem_transform_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class IncorporateGetitemFromParam : public AnfVisitor {
|
|
||||||
public:
|
|
||||||
void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &param, size_t input_idx) {
|
|
||||||
auto mng = func_graph->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
|
||||||
auto &node_users = mng->node_users();
|
|
||||||
if (node_users.find(param) == node_users.end() || node_users[param].empty()) {
|
|
||||||
args_.push_back(cnode->input(input_idx + 1));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto &user : node_users[param]) {
|
|
||||||
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
|
|
||||||
// we do not process this case.
|
|
||||||
args_.push_back(cnode->input(input_idx + 1));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// update new args.
|
|
||||||
if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) {
|
|
||||||
// case 1
|
|
||||||
replace_parameters_[input_idx] = true;
|
|
||||||
need_update_ = true;
|
|
||||||
auto make_tuple_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
|
|
||||||
auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs();
|
|
||||||
inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1;
|
|
||||||
args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end());
|
|
||||||
} else {
|
|
||||||
// case 2
|
|
||||||
auto prev_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
|
|
||||||
auto prev_fg = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
|
|
||||||
auto fg_output = prev_fg->output();
|
|
||||||
if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) {
|
|
||||||
MS_LOG(ERROR) << "The return of: " << prev_fg->ToString()
|
|
||||||
<< " should be a make tuple, but got: " << fg_output->DebugString();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
replace_parameters_[input_idx] = true;
|
|
||||||
need_update_ = true;
|
|
||||||
auto make_tuple_cnode = fg_output->cast<CNodePtr>();
|
|
||||||
inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1;
|
|
||||||
for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) {
|
|
||||||
auto new_getitem =
|
|
||||||
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToLong(output_i))});
|
|
||||||
auto aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(SizeToLong(output_i)));
|
|
||||||
new_getitem->input(2)->set_abstract(aptr);
|
|
||||||
new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract());
|
|
||||||
args_.push_back(new_getitem);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
||||||
if (node->func_graph() == nullptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
Reset();
|
|
||||||
|
|
||||||
auto cnode = node->cast<CNodePtr>();
|
|
||||||
if (cnode == nullptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto &inputs = cnode->inputs();
|
|
||||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
|
||||||
if (fg == nullptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto mng = fg->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
|
||||||
auto parameters = fg->parameters();
|
|
||||||
if (parameters.size() != inputs.size() - 1) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
replace_parameters_ = std::vector<bool>(parameters.size(), false);
|
|
||||||
inputs_num_ = std::vector<size_t>(parameters.size(), 1);
|
|
||||||
auto node_fg = node->func_graph();
|
|
||||||
|
|
||||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
|
||||||
if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) {
|
|
||||||
Process(node_fg, cnode, parameters[i - 1], i - 1);
|
|
||||||
} else {
|
|
||||||
args_.push_back(inputs[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!need_update_) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
|
|
||||||
mng->AddFuncGraph(new_fg);
|
|
||||||
|
|
||||||
auto node_users = mng->node_users();
|
|
||||||
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
|
|
||||||
std::vector<AnfNodePtr> new_parameters;
|
|
||||||
size_t curr_input_idx{0};
|
|
||||||
for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) {
|
|
||||||
if (!replace_parameters_[param_i]) {
|
|
||||||
if (parameters[param_i]->abstract() != nullptr) {
|
|
||||||
new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract());
|
|
||||||
}
|
|
||||||
new_parameters.push_back(new_fg_parameters[param_i]);
|
|
||||||
curr_input_idx++;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// make a new parameter.
|
|
||||||
for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) {
|
|
||||||
auto new_param = std::make_shared<Parameter>(new_fg);
|
|
||||||
new_param->set_abstract(args_.at(curr_input_idx)->abstract());
|
|
||||||
|
|
||||||
// update users of new parameter.
|
|
||||||
for (auto &user : node_users[new_fg_parameters[param_i]]) {
|
|
||||||
idx_ = -1;
|
|
||||||
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode<Int64Imm>})(user.first);
|
|
||||||
if (idx_ == -1) {
|
|
||||||
MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString()
|
|
||||||
<< " must be tuple getitem here, but got: " << user.first->DebugString();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (input_i == LongToSize(idx_)) {
|
|
||||||
for (auto &sub_user : node_users[user.first]) {
|
|
||||||
auto sub_user_cnode = sub_user.first->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(sub_user_cnode);
|
|
||||||
sub_user_cnode->set_input(sub_user.second, new_param);
|
|
||||||
(void)mng->Replace(sub_user.first, sub_user_cnode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
new_parameters.push_back(new_param);
|
|
||||||
curr_input_idx++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mng->SetParameters(new_fg, new_parameters);
|
|
||||||
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
|
|
||||||
auto new_call = node_fg->NewCNode(args_);
|
|
||||||
new_call->set_abstract(node->abstract());
|
|
||||||
return new_call;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int64_t>(vnode->value()); }
|
|
||||||
|
|
||||||
void Visit(const CNodePtr &cnode) override {}
|
|
||||||
|
|
||||||
void Reset() {
|
|
||||||
replace_parameters_.clear();
|
|
||||||
args_.clear();
|
|
||||||
inputs_num_.clear();
|
|
||||||
need_update_ = false;
|
|
||||||
idx_ = -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::vector<bool> replace_parameters_{};
|
|
||||||
std::vector<AnfNodePtr> args_{};
|
|
||||||
std::vector<size_t> inputs_num_{};
|
|
||||||
bool need_update_{false};
|
|
||||||
int64_t idx_{-1};
|
|
||||||
};
|
|
||||||
|
|
||||||
// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C}
|
// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C}
|
||||||
class IncorporateGetitemSwitch : public AnfVisitor {
|
class IncorporateGetitemSwitch : public AnfVisitor {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -103,13 +103,6 @@ class InlinerBase : public AnfVisitor {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do not inline GraphKernel to Cell.
|
|
||||||
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
|
||||||
// If the GraphKernel only contains a return node, we make it inlined.
|
|
||||||
if (fg->nodes().size() - fg->parameters().size() > 1) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reset();
|
Reset();
|
||||||
|
|
||||||
// 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}
|
// 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}
|
||||||
|
|
|
@ -205,130 +205,6 @@ class AddNZeroFilter : public AnfVisitor {
|
||||||
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
|
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
|
||||||
bool has_zero_like_{false};
|
bool has_zero_like_{false};
|
||||||
};
|
};
|
||||||
|
|
||||||
// {PrimAddN, {kPrimMakeTuple, Xs}}
|
|
||||||
// Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd.
|
|
||||||
// case0: AddN(inputs)(inputs size < 2) -> error
|
|
||||||
// case1: AddN(inputs)(all inputs is ValueNode) -> error
|
|
||||||
// case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor)
|
|
||||||
// case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input)
|
|
||||||
// -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
|
|
||||||
class AddNEliminater : public AnfVisitor {
|
|
||||||
public:
|
|
||||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
||||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
|
||||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
|
||||||
auto mng = fg->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
|
||||||
if (fg->recursive()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
|
|
||||||
mng->AddFuncGraph(new_fg);
|
|
||||||
need_update_ = false;
|
|
||||||
bool changed;
|
|
||||||
do {
|
|
||||||
changed = Process(new_fg);
|
|
||||||
} while (changed);
|
|
||||||
|
|
||||||
if (!need_update_) {
|
|
||||||
return nullptr;
|
|
||||||
} else {
|
|
||||||
auto new_sx = inputs;
|
|
||||||
new_sx[0] = NewValueNode(new_fg);
|
|
||||||
return node->func_graph()->NewCNode(new_sx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Process(const FuncGraphPtr &func_graph) {
|
|
||||||
auto mng = func_graph->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
|
||||||
auto nodes = TopoSort(func_graph->output());
|
|
||||||
bool changed = false;
|
|
||||||
|
|
||||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
|
||||||
auto node = nodes[i];
|
|
||||||
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto cnode = node->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
auto &tuple_input = cnode->input(1);
|
|
||||||
MS_EXCEPTION_IF_NULL(tuple_input);
|
|
||||||
auto tuple_input_cnode = tuple_input->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(tuple_input_cnode);
|
|
||||||
auto &tuple_inputs = tuple_input_cnode->inputs();
|
|
||||||
if (tuple_inputs.size() < 3) {
|
|
||||||
// case0: inputs size < 2, error
|
|
||||||
MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2);
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t valuenode_num = std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0,
|
|
||||||
[](int64_t accumulator, const AnfNodePtr &node) {
|
|
||||||
if (IsValueNode<tensor::Tensor>(node)) {
|
|
||||||
return accumulator + 1;
|
|
||||||
} else {
|
|
||||||
return accumulator;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
if (LongToSize(valuenode_num) == tuple_inputs.size()) {
|
|
||||||
// case1: all inputs is ValueNode, error
|
|
||||||
MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tuple_inputs.size() == 3) {
|
|
||||||
// case2: inputs size = 2, -> TensorAdd(Tensor, Tensor)
|
|
||||||
MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2);
|
|
||||||
ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
|
|
||||||
std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1],
|
|
||||||
tuple_inputs[2]};
|
|
||||||
mng->Replace(node, func_graph->NewCNode(new_xs));
|
|
||||||
changed = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(),
|
|
||||||
[](const AnfNodePtr &node) { return IsValueNode<tensor::Tensor>(node); });
|
|
||||||
if (first_valuenode == tuple_inputs.end()) {
|
|
||||||
// no ValueNode input found.
|
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
// case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
|
|
||||||
std::vector<AnfNodePtr> make_tuple_new_xs{
|
|
||||||
NewValueNode(prim::kPrimMakeTuple),
|
|
||||||
};
|
|
||||||
std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(),
|
|
||||||
[&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) {
|
|
||||||
if (node != *first_valuenode) {
|
|
||||||
make_tuple_new_xs.push_back(node);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations");
|
|
||||||
auto new_addn = func_graph->NewCNode(
|
|
||||||
{func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)});
|
|
||||||
ValuePtr prim_tensoradd = prim::GetPythonOps("Add", "mindspore.ops.operations");
|
|
||||||
auto new_add =
|
|
||||||
func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn});
|
|
||||||
(void)mng->Replace(node, new_add);
|
|
||||||
changed = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
need_update_ = need_update_ || changed;
|
|
||||||
return changed;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
bool need_update_{false};
|
|
||||||
};
|
|
||||||
} // namespace irpass
|
} // namespace irpass
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
#include "frontend/optimizer/irpass.h"
|
#include "frontend/optimizer/irpass.h"
|
||||||
|
@ -162,146 +161,6 @@ class SpecializeOnGraphArguments : public AnfVisitor {
|
||||||
private:
|
private:
|
||||||
internal::SpecializeTransform specialize_transform_;
|
internal::SpecializeTransform specialize_transform_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Eliminate unused parameters.
|
|
||||||
// {G, Xs}
|
|
||||||
class UnusedParasEliminater : public AnfVisitor {
|
|
||||||
public:
|
|
||||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
||||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto cnode = node->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
auto &inputs = cnode->inputs();
|
|
||||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
|
||||||
|
|
||||||
std::vector<AnfNodePtr> parameters = fg->parameters();
|
|
||||||
size_t size = parameters.size();
|
|
||||||
if (size != inputs.size() - 1) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<AnfNodePtr> new_xs;
|
|
||||||
std::vector<bool> keep_parameters;
|
|
||||||
auto mng = fg->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
|
||||||
auto &node_users = mng->node_users();
|
|
||||||
bool has_unused_para = false;
|
|
||||||
for (size_t i = 0; i < size; ++i) {
|
|
||||||
auto iter = node_users.find(parameters[i]);
|
|
||||||
if (iter != node_users.end() && !iter->second.empty()) {
|
|
||||||
keep_parameters.push_back(true);
|
|
||||||
new_xs.push_back(inputs[i + 1]);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
keep_parameters.push_back(false);
|
|
||||||
has_unused_para = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!has_unused_para) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
|
|
||||||
mng->AddFuncGraph(new_fg);
|
|
||||||
|
|
||||||
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
|
|
||||||
std::vector<AnfNodePtr> new_parameters;
|
|
||||||
for (size_t i = 0; i < size; i++) {
|
|
||||||
if (keep_parameters[i]) {
|
|
||||||
if (parameters[i]->abstract() != nullptr) {
|
|
||||||
new_fg_parameters[i]->set_abstract(parameters[i]->abstract());
|
|
||||||
}
|
|
||||||
new_parameters.push_back(new_fg_parameters[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mng->SetParameters(new_fg, new_parameters);
|
|
||||||
|
|
||||||
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
|
|
||||||
return node->func_graph()->NewCNode(new_xs);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Eliminate unused outputs.
|
|
||||||
// {G, Xs}
|
|
||||||
class UnusedOutputEliminater : public AnfVisitor {
|
|
||||||
public:
|
|
||||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
|
||||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
|
||||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
|
||||||
auto mng = fg->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
|
||||||
if (fg->recursive()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
|
|
||||||
mng->AddFuncGraph(new_fg);
|
|
||||||
auto new_fg_output = new_fg->output();
|
|
||||||
if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto output_cnode = new_fg_output->cast<CNodePtr>();
|
|
||||||
auto &node_users = mng->node_users();
|
|
||||||
if (node_users.count(node) == 0 || node_users[node].empty()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
std::unordered_set<int64_t> used_output_idx;
|
|
||||||
std::vector<std::pair<AnfNodePtr, int64_t>> all_users;
|
|
||||||
for (auto &node_user : node_users[node]) {
|
|
||||||
if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto user_cnode = node_user.first->cast<CNodePtr>();
|
|
||||||
size_t used_idx = GetValue<int64_t>(user_cnode->input(2)->cast<ValueNodePtr>()->value());
|
|
||||||
used_output_idx.insert(used_idx);
|
|
||||||
all_users.push_back(std::make_pair(node_user.first, used_idx));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (used_output_idx.size() >= output_cnode->inputs().size() - 1) {
|
|
||||||
// all output has users.
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (used_output_idx.empty()) {
|
|
||||||
// we do not process this case.
|
|
||||||
return nullptr;
|
|
||||||
} else if (used_output_idx.size() == 1) {
|
|
||||||
// after eliminate, only one output left.
|
|
||||||
new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1));
|
|
||||||
// update users.
|
|
||||||
for (auto &ret_user : all_users) {
|
|
||||||
(void)mng->Replace(ret_user.first, node);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// after eliminate, create new multi output.
|
|
||||||
std::vector<AnfNodePtr> new_output_inputs{output_cnode->input(0)};
|
|
||||||
std::unordered_map<int64_t, int64_t> new_idx_map;
|
|
||||||
for (auto idx : used_output_idx) {
|
|
||||||
new_idx_map[idx] = SizeToLong(new_output_inputs.size() - 1);
|
|
||||||
new_output_inputs.push_back(output_cnode->input(idx + 1));
|
|
||||||
}
|
|
||||||
new_fg->set_output(new_fg->NewCNode(new_output_inputs));
|
|
||||||
// update users.
|
|
||||||
for (auto &ret_user : all_users) {
|
|
||||||
auto ret_user_cnode = ret_user.first->cast<CNodePtr>();
|
|
||||||
ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto new_sx = inputs;
|
|
||||||
new_sx[0] = NewValueNode(new_fg);
|
|
||||||
return node->func_graph()->NewCNode(new_sx);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace irpass
|
} // namespace irpass
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -39,7 +39,6 @@
|
||||||
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
|
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
|
||||||
#include "frontend/optimizer/recompute.h"
|
#include "frontend/optimizer/recompute.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "utils/context/graph_kernel_flags.h"
|
|
||||||
#include "pipeline/jit/pipeline_split.h"
|
#include "pipeline/jit/pipeline_split.h"
|
||||||
#include "pipeline/jit/static_analysis/auto_monad.h"
|
#include "pipeline/jit/static_analysis/auto_monad.h"
|
||||||
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
#include "frontend/optimizer/irpass/gradient_eliminate.h"
|
||||||
|
@ -296,31 +295,6 @@ OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &i
|
||||||
return map;
|
return map;
|
||||||
}
|
}
|
||||||
|
|
||||||
OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|
||||||
OptPassGroupMap map({
|
|
||||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
|
||||||
{"cse", opt::OptPassConfig(opt::CSEPass(false))},
|
|
||||||
});
|
|
||||||
return map;
|
|
||||||
}
|
|
||||||
|
|
||||||
OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|
||||||
opt::OptPassConfig elim_1 = opt::OptPassConfig({
|
|
||||||
irpass.addn_eliminate_,
|
|
||||||
irpass.incorporate_getitem_from_param_,
|
|
||||||
});
|
|
||||||
opt::OptPassConfig elim_2 = opt::OptPassConfig({
|
|
||||||
irpass.unused_parameter_eliminate_,
|
|
||||||
irpass.unused_output_eliminate_,
|
|
||||||
});
|
|
||||||
OptPassGroupMap map({
|
|
||||||
{"elim_1", elim_1},
|
|
||||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
|
||||||
{"elim_2", elim_2},
|
|
||||||
});
|
|
||||||
return map;
|
|
||||||
}
|
|
||||||
|
|
||||||
OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) {
|
OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}});
|
return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}});
|
||||||
}
|
}
|
||||||
|
@ -375,10 +349,6 @@ void InitOpt(const ResourcePtr &res) {
|
||||||
Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
|
Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
|
||||||
g_pass_opts["opt_trans_graph"] =
|
g_pass_opts["opt_trans_graph"] =
|
||||||
Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true);
|
Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true);
|
||||||
g_pass_opts["opt_graph_kernel_a"] =
|
|
||||||
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
|
|
||||||
g_pass_opts["opt_graph_kernel_b"] =
|
|
||||||
Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false);
|
|
||||||
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
|
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
|
||||||
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
|
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
|
||||||
g_pass_opts["opt_grad_epilogue"] =
|
g_pass_opts["opt_grad_epilogue"] =
|
||||||
|
@ -386,10 +356,6 @@ void InitOpt(const ResourcePtr &res) {
|
||||||
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
|
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
|
||||||
g_pass_opts["opt_after_recompute"] =
|
g_pass_opts["opt_after_recompute"] =
|
||||||
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
|
Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
|
||||||
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
|
|
||||||
g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
|
|
||||||
g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -424,8 +390,6 @@ bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a");
|
||||||
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
|
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
|
||||||
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
|
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
|
||||||
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
|
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
|
||||||
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
|
|
||||||
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
|
|
||||||
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
|
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
|
||||||
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
|
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
|
||||||
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
|
bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
|
||||||
|
@ -559,8 +523,6 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
|
||||||
{"opt_after_cconv", OptPassAfterCconvGroup},
|
{"opt_after_cconv", OptPassAfterCconvGroup},
|
||||||
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
|
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
|
||||||
{"tuple_transform", OptPassTransformGraphGroup},
|
{"tuple_transform", OptPassTransformGraphGroup},
|
||||||
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
|
||||||
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
|
|
||||||
{"add_cache_embedding", AddCacheEmbeddingPass},
|
{"add_cache_embedding", AddCacheEmbeddingPass},
|
||||||
{"add_recomputation", AddRecomputationPass},
|
{"add_recomputation", AddRecomputationPass},
|
||||||
{"cse_after_recomputation", OptAfterRecomputeGroup}};
|
{"cse_after_recomputation", OptAfterRecomputeGroup}};
|
||||||
|
|
|
@ -216,9 +216,7 @@ void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const {
|
||||||
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
|
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
|
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
|
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
|
||||||
auto context_ptr = MsContext::GetInstance();
|
if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
|
||||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
|
||||||
if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
|
|
||||||
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
|
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
|
||||||
}
|
}
|
||||||
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
|
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
|
||||||
|
|
|
@ -602,12 +602,6 @@ bool GraphPartition::IsCut(const AnfNodePtr &node) {
|
||||||
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
||||||
}
|
}
|
||||||
AnfNodePtr fn = inputs[0];
|
AnfNodePtr fn = inputs[0];
|
||||||
if (IsValueNode<FuncGraph>(fn)) {
|
|
||||||
auto fg = GetValueNode<FuncGraphPtr>(fn);
|
|
||||||
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!IsValueNode<Primitive>(fn)) {
|
if (!IsValueNode<Primitive>(fn)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -510,7 +510,7 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
|
||||||
MS_EXCEPTION_IF_NULL(prim_graph);
|
MS_EXCEPTION_IF_NULL(prim_graph);
|
||||||
FuncGraphSet graphs = prim_graph->manager()->func_graphs();
|
FuncGraphSet graphs = prim_graph->manager()->func_graphs();
|
||||||
for (auto g : graphs) {
|
for (auto g : graphs) {
|
||||||
if (g != graph && g != nullptr && !(g->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
|
if (g != graph && g != nullptr) {
|
||||||
Compile(g);
|
Compile(g);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -542,7 +542,7 @@ uint32_t GraphCompiler::CompileGraphs(const FuncGraphPtr &func_graph) {
|
||||||
// Compile sub graphs.
|
// Compile sub graphs.
|
||||||
FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
|
FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
|
||||||
for (auto sub_graph : sub_graphs) {
|
for (auto sub_graph : sub_graphs) {
|
||||||
if (sub_graph != func_graph && sub_graph != nullptr && !(sub_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) {
|
if (sub_graph != func_graph && sub_graph != nullptr) {
|
||||||
(void)CompileGraph(sub_graph);
|
(void)CompileGraph(sub_graph);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import numpy
|
||||||
|
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
|
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
|
||||||
|
from mindspore.common._decorator import deprecated
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from .. import context
|
from .. import context
|
||||||
from .._c_expression import init_pipeline, Cell_, FuncGraph
|
from .._c_expression import init_pipeline, Cell_, FuncGraph
|
||||||
|
@ -1213,6 +1214,11 @@ class GraphKernel(Cell):
|
||||||
A `GraphKernel` a composite of basic primitives and can be compiled into a fused kernel automatically when
|
A `GraphKernel` a composite of basic primitives and can be compiled into a fused kernel automatically when
|
||||||
enable_graph_kernel in context is set to True.
|
enable_graph_kernel in context is set to True.
|
||||||
|
|
||||||
|
This class is deprecated from version 1.3 and will be removed in a future version, use Cell instead.
|
||||||
|
|
||||||
|
GraphKernel is not supported user-defined cells anymore, the `GraphKernel` objects will be treated as
|
||||||
|
normal `Cell` objects.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
auto_prefix (bool): Recursively generate namespaces. Default: True.
|
auto_prefix (bool): Recursively generate namespaces. Default: True.
|
||||||
flags (dict) : Set graph flags. Default: None.
|
flags (dict) : Set graph flags. Default: None.
|
||||||
|
@ -1230,10 +1236,9 @@ class GraphKernel(Cell):
|
||||||
... return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
|
... return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@deprecated("1.3", "Cell", True)
|
||||||
def __init__(self, auto_prefix=True, flags=None):
|
def __init__(self, auto_prefix=True, flags=None):
|
||||||
super(GraphKernel, self).__init__(auto_prefix, flags)
|
super(GraphKernel, self).__init__(auto_prefix, flags)
|
||||||
class_name = self.__class__.__name__
|
|
||||||
self.add_flags(graph_kernel=class_name)
|
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import _selected_ops
|
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
@ -87,7 +86,7 @@ class Softmax(Cell):
|
||||||
|
|
||||||
def __init__(self, axis=-1):
|
def __init__(self, axis=-1):
|
||||||
super(Softmax, self).__init__()
|
super(Softmax, self).__init__()
|
||||||
self.softmax = _selected_ops.Softmax(axis)
|
self.softmax = P.Softmax(axis)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
return self.softmax(x)
|
return self.softmax(x)
|
||||||
|
@ -137,7 +136,7 @@ class LogSoftmax(Cell):
|
||||||
|
|
||||||
def __init__(self, axis=-1):
|
def __init__(self, axis=-1):
|
||||||
super(LogSoftmax, self).__init__()
|
super(LogSoftmax, self).__init__()
|
||||||
self.log_softmax = _selected_ops.LogSoftmax(axis)
|
self.log_softmax = P.LogSoftmax(axis)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
return self.log_softmax(x)
|
return self.log_softmax(x)
|
||||||
|
@ -368,7 +367,7 @@ class Tanh(Cell):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Tanh, self).__init__()
|
super(Tanh, self).__init__()
|
||||||
self.tanh = _selected_ops.Tanh()
|
self.tanh = P.Tanh()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
return self.tanh(x)
|
return self.tanh(x)
|
||||||
|
@ -415,7 +414,7 @@ class GELU(Cell):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(GELU, self).__init__()
|
super(GELU, self).__init__()
|
||||||
self.gelu = _selected_ops.GeLU()
|
self.gelu = P.GeLU()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
return self.gelu(x)
|
return self.gelu(x)
|
||||||
|
@ -458,7 +457,7 @@ class FastGelu(Cell):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(FastGelu, self).__init__()
|
super(FastGelu, self).__init__()
|
||||||
self.fast_gelu = _selected_ops.FastGeLU()
|
self.fast_gelu = P.FastGeLU()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
return self.fast_gelu(x)
|
return self.fast_gelu(x)
|
||||||
|
|
|
@ -30,7 +30,6 @@ from mindspore._checkparam import Validator as validator
|
||||||
from mindspore._extends import cell_attr_register
|
from mindspore._extends import cell_attr_register
|
||||||
from mindspore.communication.management import get_group_size, get_rank
|
from mindspore.communication.management import get_group_size, get_rank
|
||||||
from mindspore.communication import management
|
from mindspore.communication import management
|
||||||
from mindspore.ops import _selected_ops
|
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from ..cell import Cell
|
from ..cell import Cell
|
||||||
|
|
||||||
|
@ -837,9 +836,9 @@ class LayerNorm(Cell):
|
||||||
gamma_init, normalized_shape), name="gamma")
|
gamma_init, normalized_shape), name="gamma")
|
||||||
self.beta = Parameter(initializer(
|
self.beta = Parameter(initializer(
|
||||||
beta_init, normalized_shape), name="beta")
|
beta_init, normalized_shape), name="beta")
|
||||||
self.layer_norm = _selected_ops.LayerNorm(begin_norm_axis=self.begin_norm_axis,
|
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis,
|
||||||
begin_params_axis=self.begin_params_axis,
|
begin_params_axis=self.begin_params_axis,
|
||||||
epsilon=self.epsilon)
|
epsilon=self.epsilon)
|
||||||
|
|
||||||
def construct(self, input_x):
|
def construct(self, input_x):
|
||||||
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
|
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
|
||||||
|
|
|
@ -21,7 +21,6 @@ from mindspore.ops import operations as P
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr
|
||||||
from mindspore.ops import _selected_ops
|
|
||||||
from mindspore.nn.cell import Cell
|
from mindspore.nn.cell import Cell
|
||||||
from mindspore.nn.layer.activation import get_activation
|
from mindspore.nn.layer.activation import get_activation
|
||||||
from mindspore._checkparam import Validator as validator
|
from mindspore._checkparam import Validator as validator
|
||||||
|
@ -48,7 +47,7 @@ class _Loss(Cell):
|
||||||
if reduction == 'none':
|
if reduction == 'none':
|
||||||
self.reduce = False
|
self.reduce = False
|
||||||
|
|
||||||
self.reduce_mean = _selected_ops.ReduceMean()
|
self.reduce_mean = P.ReduceMean()
|
||||||
self.reduce_sum = P.ReduceSum()
|
self.reduce_sum = P.ReduceSum()
|
||||||
self.mul = P.Mul()
|
self.mul = P.Mul()
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
|
@ -381,7 +380,7 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
|
||||||
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
|
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
|
||||||
self.sparse = validator.check_bool(sparse, "sparse")
|
self.sparse = validator.check_bool(sparse, "sparse")
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits()
|
self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
|
||||||
self.one_hot = P.OneHot()
|
self.one_hot = P.OneHot()
|
||||||
self.on_value = Tensor(1.0, mstype.float32)
|
self.on_value = Tensor(1.0, mstype.float32)
|
||||||
self.off_value = Tensor(0., mstype.float32)
|
self.off_value = Tensor(0., mstype.float32)
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""momentum"""
|
"""momentum"""
|
||||||
from mindspore.ops import functional as F, composite as C, operations as P
|
from mindspore.ops import functional as F, composite as C, operations as P
|
||||||
from mindspore.ops import _selected_ops
|
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
|
@ -157,7 +156,7 @@ class Momentum(Optimizer):
|
||||||
self.use_nesterov = Validator.check_bool(use_nesterov)
|
self.use_nesterov = Validator.check_bool(use_nesterov)
|
||||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||||
self.hyper_map = C.HyperMap()
|
self.hyper_map = C.HyperMap()
|
||||||
self.opt = _selected_ops.ApplyMomentum(use_nesterov=self.use_nesterov)
|
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
|
||||||
|
|
||||||
def construct(self, gradients):
|
def construct(self, gradients):
|
||||||
params = self.params
|
params = self.params
|
||||||
|
|
|
@ -19,7 +19,6 @@ from functools import reduce
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore.ops import _selected_grad_ops as SG
|
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
from ..operations import _grad_ops as G
|
from ..operations import _grad_ops as G
|
||||||
|
@ -590,7 +589,7 @@ def get_bprop_expm1(self):
|
||||||
@bprop_getters.register(P.Minimum)
|
@bprop_getters.register(P.Minimum)
|
||||||
def get_bprop_minimum(self):
|
def get_bprop_minimum(self):
|
||||||
"""Grad definition for `Minimum` operation."""
|
"""Grad definition for `Minimum` operation."""
|
||||||
input_grad = SG.MinimumGrad()
|
input_grad = G.MinimumGrad()
|
||||||
|
|
||||||
def bprop(x, y, out, dout):
|
def bprop(x, y, out, dout):
|
||||||
dx, dy = input_grad(x, y, dout)
|
dx, dy = input_grad(x, y, dout)
|
||||||
|
@ -602,7 +601,7 @@ def get_bprop_minimum(self):
|
||||||
@bprop_getters.register(P.Maximum)
|
@bprop_getters.register(P.Maximum)
|
||||||
def get_bprop_maximum(self):
|
def get_bprop_maximum(self):
|
||||||
"""Grad definition for `Maximum` operation."""
|
"""Grad definition for `Maximum` operation."""
|
||||||
input_grad = SG.MaximumGrad()
|
input_grad = G.MaximumGrad()
|
||||||
|
|
||||||
def bprop(x, y, out, dout):
|
def bprop(x, y, out, dout):
|
||||||
dx, dy = input_grad(x, y, dout)
|
dx, dy = input_grad(x, y, dout)
|
||||||
|
@ -1107,7 +1106,7 @@ def get_bprop_cosh(self):
|
||||||
@bprop_getters.register(P.Abs)
|
@bprop_getters.register(P.Abs)
|
||||||
def get_bprop_abs(self):
|
def get_bprop_abs(self):
|
||||||
"""Grad definition for `Abs` operation."""
|
"""Grad definition for `Abs` operation."""
|
||||||
abs_grad = SG.AbsGrad()
|
abs_grad = G.AbsGrad()
|
||||||
|
|
||||||
def bprop(x, out, dout):
|
def bprop(x, out, dout):
|
||||||
dx = abs_grad(x, dout)
|
dx = abs_grad(x, dout)
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
|
|
||||||
"""Define the grad rules of neural network related operations."""
|
"""Define the grad rules of neural network related operations."""
|
||||||
import os
|
import os
|
||||||
from mindspore.ops import _selected_grad_ops as SG
|
|
||||||
from mindspore.ops.primitive import constexpr
|
from mindspore.ops.primitive import constexpr
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.ops.operations import nn_ops as nps
|
from mindspore.ops.operations import nn_ops as nps
|
||||||
|
@ -34,7 +33,7 @@ env_force_bprop_seq = os.getenv("ENV_FORCE_BPROP_SEQ")
|
||||||
@bprop_getters.register(P.BiasAdd)
|
@bprop_getters.register(P.BiasAdd)
|
||||||
def get_bprop_bias_add(self):
|
def get_bprop_bias_add(self):
|
||||||
"""Grad definition for `BiasAdd` operation."""
|
"""Grad definition for `BiasAdd` operation."""
|
||||||
bias_grad = SG.BiasAddGrad(self.data_format)
|
bias_grad = G.BiasAddGrad(self.data_format)
|
||||||
|
|
||||||
def bprop(x, w, out, dout):
|
def bprop(x, w, out, dout):
|
||||||
return dout, bias_grad(dout)
|
return dout, bias_grad(dout)
|
||||||
|
@ -341,7 +340,7 @@ def get_bprop_dropout_do_mask(self):
|
||||||
def get_bprop_mish(self):
|
def get_bprop_mish(self):
|
||||||
"""Grad definition for `Mish` operation."""
|
"""Grad definition for `Mish` operation."""
|
||||||
tanh = P.Tanh()
|
tanh = P.Tanh()
|
||||||
tanh_grad = SG.TanhGrad()
|
tanh_grad = G.TanhGrad()
|
||||||
softplus = P.Softplus()
|
softplus = P.Softplus()
|
||||||
softplus_grad = G.SoftplusGrad()
|
softplus_grad = G.SoftplusGrad()
|
||||||
|
|
||||||
|
@ -580,7 +579,7 @@ def get_bprop_softsign(self):
|
||||||
@bprop_getters.register(P.Tanh)
|
@bprop_getters.register(P.Tanh)
|
||||||
def get_bprop_tanh(self):
|
def get_bprop_tanh(self):
|
||||||
"""Grad definition for `Tanh` operation."""
|
"""Grad definition for `Tanh` operation."""
|
||||||
tanh_grad = SG.TanhGrad()
|
tanh_grad = G.TanhGrad()
|
||||||
|
|
||||||
def bprop(x, out, dout):
|
def bprop(x, out, dout):
|
||||||
dx = tanh_grad(out, dout)
|
dx = tanh_grad(out, dout)
|
||||||
|
|
|
@ -1,50 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
""" resolved grad ops """
|
|
||||||
from mindspore.ops.op_selector import new_ops_selector
|
|
||||||
|
|
||||||
op_selector = new_ops_selector(
|
|
||||||
"mindspore.ops.operations._grad_ops", "mindspore.nn._graph_kernels")
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class MaximumGrad:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class MinimumGrad:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class AbsGrad:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class BiasAddGrad:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class TanhGrad:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
|
@ -1,120 +0,0 @@
|
||||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
""" resolve ops """
|
|
||||||
from mindspore.ops.op_selector import new_ops_selector
|
|
||||||
|
|
||||||
op_selector = new_ops_selector(
|
|
||||||
"mindspore.ops.operations", "mindspore.nn._graph_kernels")
|
|
||||||
opt_selector = new_ops_selector(
|
|
||||||
"mindspore.nn.optim", "mindspore.nn._graph_kernels")
|
|
||||||
nn_selector = new_ops_selector(
|
|
||||||
"mindspore.nn", "mindspore.nn._graph_kernels")
|
|
||||||
|
|
||||||
|
|
||||||
@nn_selector
|
|
||||||
class BatchNorm2d:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class ReLU:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class ReduceMean:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class BiasAdd:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class ApplyMomentum:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class SoftmaxCrossEntropyWithLogits:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class LogSoftmax:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class Tanh:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class GeLU:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class FastGeLU:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class LayerNorm:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class Softmax:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class LambUpdateWithLR:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class LambNextMV:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class LambApplyOptimizerAssign:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@op_selector
|
|
||||||
class LambApplyWeightAssign:
|
|
||||||
def __call__(self, *args):
|
|
||||||
pass
|
|
|
@ -20,6 +20,7 @@ which can be used to control the switch of op type: GraphKernel or Primitive.
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
from mindspore.common._decorator import deprecated
|
||||||
|
|
||||||
|
|
||||||
class _OpSelector:
|
class _OpSelector:
|
||||||
|
@ -70,6 +71,7 @@ class _OpSelector:
|
||||||
return op(*args, **kwargs)
|
return op(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated("1.3", "basic Primitive", False)
|
||||||
def new_ops_selector(primitive_pkg, graph_kernel_pkg):
|
def new_ops_selector(primitive_pkg, graph_kernel_pkg):
|
||||||
"""
|
"""
|
||||||
A factory method to return an op selector
|
A factory method to return an op selector
|
||||||
|
@ -83,6 +85,8 @@ def new_ops_selector(primitive_pkg, graph_kernel_pkg):
|
||||||
The order of the highest priority to lowest priority is (1), (2), (3)
|
The order of the highest priority to lowest priority is (1), (2), (3)
|
||||||
If the GraphKernel switch is off, then op_type will always be PRIMITIVE.
|
If the GraphKernel switch is off, then op_type will always be PRIMITIVE.
|
||||||
|
|
||||||
|
The user-defined GraphKernel Cell is deprecated, this interface will be removed in a future version.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
primitive_pkg (str): primitive op's package name
|
primitive_pkg (str): primitive op's package name
|
||||||
graph_kernel_pkg (str): graph kernel op's package name
|
graph_kernel_pkg (str): graph kernel op's package name
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
"""Other operators."""
|
"""Other operators."""
|
||||||
import functools
|
import functools
|
||||||
from mindspore.common import monad
|
from mindspore.common import monad
|
||||||
|
from mindspore.common._decorator import deprecated
|
||||||
from .. import signature as sig
|
from .. import signature as sig
|
||||||
from ..._checkparam import Validator as validator, Rel
|
from ..._checkparam import Validator as validator, Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
|
@ -84,6 +85,8 @@ class InplaceAssign(PrimitiveWithInfer):
|
||||||
Inplace assign `Parameter` with a value.
|
Inplace assign `Parameter` with a value.
|
||||||
This primitive can only use in graph kernel.
|
This primitive can only use in graph kernel.
|
||||||
|
|
||||||
|
InplaceAssign is deprecated from version 1.3 and will be removed in a future version, use Assign instead.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **variable** (Parameter) - The `Parameter`.
|
- **variable** (Parameter) - The `Parameter`.
|
||||||
- **value** (Tensor) - The value to be assigned.
|
- **value** (Tensor) - The value to be assigned.
|
||||||
|
@ -110,7 +113,8 @@ class InplaceAssign(PrimitiveWithInfer):
|
||||||
>>> net = Net()
|
>>> net = Net()
|
||||||
>>> output = net(x)
|
>>> output = net(x)
|
||||||
>>> print(output)
|
>>> print(output)
|
||||||
"""
|
"""
|
||||||
|
@deprecated("1.3", "Assign", False)
|
||||||
@ prim_attr_register
|
@ prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])
|
||||||
|
|
|
@ -26,7 +26,6 @@ from mindspore.nn.loss.loss import _Loss
|
||||||
from mindspore.nn.optim.momentum import Momentum
|
from mindspore.nn.optim.momentum import Momentum
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore.ops import _selected_ops
|
|
||||||
from mindspore.parallel._utils import _reset_op_id
|
from mindspore.parallel._utils import _reset_op_id
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
|
@ -76,7 +75,7 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
|
||||||
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
|
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction)
|
||||||
self.sparse = sparse
|
self.sparse = sparse
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
self.softmax_cross_entropy = _selected_ops.SoftmaxCrossEntropyWithLogits()
|
self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits()
|
||||||
self.one_hot = P.OneHot()
|
self.one_hot = P.OneHot()
|
||||||
self.on_value = Tensor(1.0, mstype.float32)
|
self.on_value = Tensor(1.0, mstype.float32)
|
||||||
self.off_value = Tensor(0., mstype.float32)
|
self.off_value = Tensor(0., mstype.float32)
|
||||||
|
|
Loading…
Reference in New Issue