forked from OSSInnovation/mindspore
Remove the useless TransData and Cast nodes connected with ControlDepend
This commit is contained in:
parent
eaaacfea4c
commit
0ac5911910
|
@ -62,12 +62,12 @@ void ValidateOperation(const AnfNodePtr &node) {
|
|||
|
||||
void ValidateAbstract(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
MS_LOG(WARNING) << "Node to validate is invalid";
|
||||
MS_LOG(DEBUG) << "Node to validate is invalid";
|
||||
return;
|
||||
}
|
||||
AbstractBasePtr ptrBase = node->abstract();
|
||||
if (ptrBase == nullptr) {
|
||||
MS_LOG(WARNING) << "Abstract is null in node: " << node->DebugString();
|
||||
MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString();
|
||||
return;
|
||||
}
|
||||
if (ptrBase->isa<AbstractClass>() || ptrBase->isa<AbstractJTagged>()) {
|
||||
|
|
|
@ -61,16 +61,14 @@ bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type,
|
|||
|
||||
bool GetNextNodeAndCastIndex(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodePtr *next_node,
|
||||
size_t *cast_index) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// Check whether the cast node is used for input by only one another node.
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (manager->node_users().find(node) == manager->node_users().end() || manager->node_users()[node].size() != 1) {
|
||||
auto output_node_list = GetRealNodeUsedList(graph, node);
|
||||
MS_EXCEPTION_IF_NULL(output_node_list);
|
||||
if (output_node_list->size() != 1) {
|
||||
return false;
|
||||
}
|
||||
*next_node = manager->node_users()[node].begin()->first;
|
||||
*cast_index = IntToSize(manager->node_users()[node].begin()->second - 1);
|
||||
auto node_pair = output_node_list->at(0);
|
||||
*next_node = node_pair.first;
|
||||
*cast_index = node_pair.second - 1;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -148,7 +146,10 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
|
|||
if (alternative_kernel_info == kernel_info_list.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_op_name;
|
||||
auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(next_node);
|
||||
MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_cnode->DebugString()
|
||||
<< "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
|
||||
<< (*alternative_kernel_info)->ToString();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get());
|
||||
if (node->inputs().size() < kCastInputNum) {
|
||||
auto op_name = AnfAlgo::GetCNodeName(node);
|
||||
|
@ -217,8 +218,11 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_nod
|
|||
if (kernel_info_it == kernel_info_list.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(prior_op);
|
||||
MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << prior_op->DebugString()
|
||||
<< "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info"
|
||||
<< (*kernel_info_it)->ToString();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get());
|
||||
|
||||
auto prior_name = AnfAlgo::GetCNodeName(prior_op);
|
||||
if (prior_name == kFive2FourOpName) {
|
||||
AnfAlgo::CopyNodeAttr("dst_type", "dstType", cur_node, prior_op);
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include <deque>
|
||||
#include "utils/utils.h"
|
||||
#include "utils/base_ref.h"
|
||||
|
@ -472,15 +473,38 @@ void RemoveNopNode(session::KernelGraph *const graph) {
|
|||
}
|
||||
}
|
||||
|
||||
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
||||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
|
||||
const AnfNodePtr &node) {
|
||||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> output_node_list =
|
||||
std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
if (manager->node_users().find(node) == manager->node_users().end()) {
|
||||
auto iter = manager->node_users().find(node);
|
||||
if (iter == manager->node_users().end()) {
|
||||
MS_LOG(EXCEPTION) << "node has no output in manager";
|
||||
}
|
||||
return manager->node_users()[node].size() > 1;
|
||||
auto output_info_list = iter->second;
|
||||
for (const auto &output_info : output_info_list) {
|
||||
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
|
||||
output_info.second == kDependAttachNodeIndex) {
|
||||
continue;
|
||||
}
|
||||
output_node_list->push_back(output_info);
|
||||
}
|
||||
return output_node_list;
|
||||
}
|
||||
|
||||
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto output_node_list = GetRealNodeUsedList(graph, node);
|
||||
MS_EXCEPTION_IF_NULL(output_node_list);
|
||||
return output_node_list->size() > 1;
|
||||
}
|
||||
|
||||
AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) {
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <unordered_set>
|
||||
#include "ir/func_graph.h"
|
||||
#include "session/kernel_graph.h"
|
||||
|
@ -160,6 +161,9 @@ void RemoveNopNode(session::KernelGraph *const graph);
|
|||
|
||||
AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx);
|
||||
|
||||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
|
||||
const AnfNodePtr &node);
|
||||
|
||||
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
|
||||
|
||||
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);
|
||||
|
|
|
@ -44,11 +44,11 @@ AnfNodePtr GetReplaceNode(const AnfNodePtr &node) {
|
|||
return cnode->input(kSingleInputIndex);
|
||||
}
|
||||
|
||||
bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) {
|
||||
return false;
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> new_make_tuple_inputs;
|
||||
bool need_update = false;
|
||||
|
@ -75,17 +75,16 @@ bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
|||
auto manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(cnode, new_make_tuple);
|
||||
return new_make_tuple;
|
||||
}
|
||||
return true;
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Builds the match pattern for this pass: a VectorRef made of one Var followed by one SeqVar.
// NOTE(review): presumably this matches any CNode (head input plus the remaining inputs), with the
// Depend / ControlDepend filtering left to Process -- confirm against the pattern engine.
const BaseRef OptimizeDependence::DefinePattern() const {
  VarPtr head = std::make_shared<Var>();
  VarPtr rest = std::make_shared<SeqVar>();
  return VectorRef({head, rest});
}
|
||||
|
||||
const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
|
@ -95,29 +94,50 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
|
|||
if (!node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(node) != prim::kPrimControlDepend->name() &&
|
||||
AnfAlgo::GetCNodeName(node) != prim::kPrimDepend->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
size_t index = 0;
|
||||
auto depend_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||
CheckCNodeInputSize(depend_cnode, kDependInputNum);
|
||||
auto replacing_node = depend_cnode->input(kDependInputNum - 1);
|
||||
MS_EXCEPTION_IF_NULL(replacing_node);
|
||||
if (!replacing_node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)};
|
||||
if (AnfAlgo::GetCNodeName(node) == prim::kPrimDepend->name()) {
|
||||
index = 1;
|
||||
new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend));
|
||||
}
|
||||
auto replacing_cnode = replacing_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(replacing_cnode);
|
||||
// Deal with the make_tuple with TransData or Cast inputs.
|
||||
if (ReplaceMakeTuple(func_graph, replacing_cnode)) {
|
||||
return nullptr;
|
||||
if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) {
|
||||
MS_LOG(EXCEPTION) << "The depend node input size is at less size 2,but got "
|
||||
<< AnfAlgo::GetInputTensorNum(depend_cnode) << depend_cnode->DebugString();
|
||||
}
|
||||
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
|
||||
if (replace_node == nullptr) {
|
||||
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString();
|
||||
return nullptr;
|
||||
|
||||
while (index < AnfAlgo::GetInputTensorNum(depend_cnode)) {
|
||||
auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
|
||||
++index;
|
||||
MS_EXCEPTION_IF_NULL(replacing_node);
|
||||
if (!replacing_node->isa<CNode>()) {
|
||||
new_depend_inputs.push_back(replacing_node);
|
||||
continue;
|
||||
}
|
||||
auto replacing_cnode = replacing_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(replacing_cnode);
|
||||
// Deal with the make_tuple with TransData or Cast inputs.
|
||||
auto make_tuple_replace_node = ReplaceMakeTuple(func_graph, replacing_cnode);
|
||||
if (make_tuple_replace_node != nullptr) {
|
||||
new_depend_inputs.push_back(make_tuple_replace_node);
|
||||
continue;
|
||||
}
|
||||
AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
|
||||
if (replace_node == nullptr) {
|
||||
new_depend_inputs.push_back(replacing_node);
|
||||
MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: "
|
||||
<< node->DebugString();
|
||||
continue;
|
||||
}
|
||||
new_depend_inputs.push_back(replace_node);
|
||||
}
|
||||
std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex),
|
||||
depend_cnode->input(kRealInputIndexInDepend), replace_node};
|
||||
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
CNodePtr new_depend;
|
||||
CNodePtr new_depend = nullptr;
|
||||
if (kernel_graph == nullptr) {
|
||||
new_depend = func_graph->NewCNode(new_depend_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_depend);
|
||||
|
|
|
@ -171,18 +171,18 @@ def test_bert_tdt():
|
|||
|
||||
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
|
||||
loss_value = np.array(callback.loss_list)
|
||||
expect_loss_value = [12.207198, 11.980881, 11.984844, 11.879381, 11.832978, 12.411333, 12.009284,
|
||||
12.621277, 12.223178, 12.427385]
|
||||
expect_loss_value = [12.207198, 11.865665, 11.828972, 11.827378, 11.821808, 12.408042, 12.00606,
|
||||
12.621794, 12.223485, 12.427612]
|
||||
print("loss value: {}".format(loss_value))
|
||||
assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
|
||||
|
||||
overflow = np.array(callback.overflow_list)
|
||||
expect_overflow = [True, True, False, False, False, True, False, False, False, True]
|
||||
expect_overflow = [False, False, False, True, False, False, False, True, False, False]
|
||||
print("overflow: {}".format(overflow))
|
||||
assert (overflow == expect_overflow).all()
|
||||
|
||||
loss_scale = np.array(callback.lossscale_list)
|
||||
expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
|
||||
expect_loss_scale = [65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0, 131072.0, 65536.0, 65536.0, 65536.0]
|
||||
print("loss scale: {}".format(loss_scale))
|
||||
assert np.allclose(loss_scale, expect_loss_scale, 0, 0)
|
||||
|
||||
|
|
Loading…
Reference in New Issue