codeclean
This commit is contained in:
parent
7513363f30
commit
d8addd93b7
|
@ -39,6 +39,7 @@ bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) const {
|
int64_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (HasFusionIdAttr(node)) {
|
if (HasFusionIdAttr(node)) {
|
||||||
return common::AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFusionId);
|
return common::AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFusionId);
|
||||||
}
|
}
|
||||||
|
@ -46,6 +47,7 @@ int64_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int64_t id) const {
|
void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int64_t id) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
ValuePtr fusion_id_v = MakeValue(id);
|
ValuePtr fusion_id_v = MakeValue(id);
|
||||||
common::AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node);
|
common::AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node);
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,10 +43,13 @@ constexpr size_t kType64Len = 8;
|
||||||
constexpr auto kNopNodeRealInputIndex = 1;
|
constexpr auto kNopNodeRealInputIndex = 1;
|
||||||
|
|
||||||
void UpdateDumpFlagAndDebugInfo(const CNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
|
void UpdateDumpFlagAndDebugInfo(const CNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
std::vector<AnfNodePtr> orig_real_cnodes;
|
std::vector<AnfNodePtr> orig_real_cnodes;
|
||||||
for (auto &orig_node : orig_nodes) {
|
for (auto &orig_node : orig_nodes) {
|
||||||
|
MS_EXCEPTION_IF_NULL(orig_node);
|
||||||
if (AnfUtils::IsRealCNodeKernel(orig_node)) {
|
if (AnfUtils::IsRealCNodeKernel(orig_node)) {
|
||||||
auto orig_cnode = orig_node->cast<CNodePtr>();
|
auto orig_cnode = orig_node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(orig_cnode);
|
||||||
if (common::AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
|
if (common::AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
|
||||||
common::AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
|
common::AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
|
||||||
}
|
}
|
||||||
|
@ -160,6 +163,7 @@ bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y
|
||||||
|
|
||||||
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
|
||||||
auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
|
auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
|
||||||
MS_EXCEPTION_IF_NULL(transop_cnode);
|
MS_EXCEPTION_IF_NULL(transop_cnode);
|
||||||
|
@ -620,6 +624,8 @@ ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vec
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, const CNodePtr &node, const bool is_input) {
|
CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, const CNodePtr &node, const bool is_input) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
std::vector<AnfNodePtr> new_cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name()))};
|
std::vector<AnfNodePtr> new_cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name()))};
|
||||||
BaseShapePtr shape;
|
BaseShapePtr shape;
|
||||||
if (is_input) {
|
if (is_input) {
|
||||||
|
@ -640,6 +646,8 @@ CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, cons
|
||||||
|
|
||||||
AnfNodePtr CreateNodeBase(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &new_node_inputs,
|
AnfNodePtr CreateNodeBase(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &new_node_inputs,
|
||||||
const AnfNodePtr &node) {
|
const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto new_node = graph->NewCNode(new_node_inputs);
|
auto new_node = graph->NewCNode(new_node_inputs);
|
||||||
MS_EXCEPTION_IF_NULL(new_node);
|
MS_EXCEPTION_IF_NULL(new_node);
|
||||||
|
|
||||||
|
@ -840,6 +848,7 @@ AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
|
||||||
|
|
||||||
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
|
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &input_abstract) {
|
const AbstractBasePtrList &input_abstract) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
|
auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
|
||||||
if (dynamic_inputs_list == nullptr) {
|
if (dynamic_inputs_list == nullptr) {
|
||||||
return input_abstract;
|
return input_abstract;
|
||||||
|
|
|
@ -54,6 +54,8 @@ void AddOutputAndCallerToMap(const CNodePtr &cnode, mindspore::HashMap<AnfNodePt
|
||||||
|
|
||||||
void SkipSameOp(const AnfNodePtr &old_node, const AnfNodePtr &new_node, mindspore::HashSet<AnfNodePtr> *seen_node) {
|
void SkipSameOp(const AnfNodePtr &old_node, const AnfNodePtr &new_node, mindspore::HashSet<AnfNodePtr> *seen_node) {
|
||||||
MS_EXCEPTION_IF_NULL(seen_node);
|
MS_EXCEPTION_IF_NULL(seen_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(old_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(new_node);
|
||||||
if (old_node->isa<CNode>() && new_node->isa<CNode>() &&
|
if (old_node->isa<CNode>() && new_node->isa<CNode>() &&
|
||||||
(common::AnfAlgo::GetCNodeName(old_node) == common::AnfAlgo::GetCNodeName(new_node))) {
|
(common::AnfAlgo::GetCNodeName(old_node) == common::AnfAlgo::GetCNodeName(new_node))) {
|
||||||
(void)seen_node->insert(new_node);
|
(void)seen_node->insert(new_node);
|
||||||
|
|
|
@ -33,6 +33,8 @@ mindspore::HashMap<std::string, mindspore::HashSet<std::string>> MarkOp{
|
||||||
{"LSTM", {"LSTMGradWeight", "LSTMGrad", "LSTMGradData"}}};
|
{"LSTM", {"LSTMGradWeight", "LSTMGrad", "LSTMGradData"}}};
|
||||||
|
|
||||||
bool CheckOP(const FuncGraphManagerPtr &manager, const AnfNodePtr &cnode, const mindspore::HashSet<std::string> &set) {
|
bool CheckOP(const FuncGraphManagerPtr &manager, const AnfNodePtr &cnode, const mindspore::HashSet<std::string> &set) {
|
||||||
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
for (const auto &node_index : manager->node_users()[cnode]) {
|
for (const auto &node_index : manager->node_users()[cnode]) {
|
||||||
auto output = node_index.first;
|
auto output = node_index.first;
|
||||||
MS_EXCEPTION_IF_NULL(output);
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
|
|
|
@ -81,12 +81,14 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusion
|
||||||
if (recompute_min_fusion_id <= unrecompute_max_fusion_id) {
|
if (recompute_min_fusion_id <= unrecompute_max_fusion_id) {
|
||||||
MS_LOG(WARNING) << "Increase the duplicated allgather fusion id";
|
MS_LOG(WARNING) << "Increase the duplicated allgather fusion id";
|
||||||
for (auto &adjust_node : parallel_optimizer_recompute_first_fusion_allgathers) {
|
for (auto &adjust_node : parallel_optimizer_recompute_first_fusion_allgathers) {
|
||||||
|
MS_EXCEPTION_IF_NULL(adjust_node);
|
||||||
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
|
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
|
||||||
int64_t destination_fusion_id =
|
int64_t destination_fusion_id =
|
||||||
(kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id;
|
(kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id;
|
||||||
common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
|
common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
|
||||||
}
|
}
|
||||||
for (auto &adjust_node : parallel_optimizer_recompute_allgathers) {
|
for (auto &adjust_node : parallel_optimizer_recompute_allgathers) {
|
||||||
|
MS_EXCEPTION_IF_NULL(adjust_node);
|
||||||
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
|
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
|
||||||
int64_t destination_fusion_id =
|
int64_t destination_fusion_id =
|
||||||
(kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id;
|
(kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id;
|
||||||
|
@ -97,19 +99,26 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusion
|
||||||
|
|
||||||
bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
|
bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
|
||||||
const FuncGraphPtr &graph, const std::vector<AnfNodePtr> ¶llel_optimizer_recompute_allgathers) const {
|
const FuncGraphPtr &graph, const std::vector<AnfNodePtr> ¶llel_optimizer_recompute_allgathers) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
FuncGraphManagerPtr manager = graph->manager();
|
FuncGraphManagerPtr manager = graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
for (auto &node : parallel_optimizer_recompute_allgathers) {
|
for (auto &node : parallel_optimizer_recompute_allgathers) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
auto depend_node = common::AnfAlgo::GetInputNode(cnode, 0);
|
auto depend_node = common::AnfAlgo::GetInputNode(cnode, 0);
|
||||||
|
MS_EXCEPTION_IF_NULL(depend_node);
|
||||||
auto set_edge_node = node;
|
auto set_edge_node = node;
|
||||||
if (IsPrimitiveCNode(depend_node, prim::kPrimTensorMove)) {
|
if (IsPrimitiveCNode(depend_node, prim::kPrimTensorMove)) {
|
||||||
auto tensormove_cnode = depend_node->cast<CNodePtr>();
|
auto tensormove_cnode = depend_node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tensormove_cnode);
|
||||||
set_edge_node = depend_node;
|
set_edge_node = depend_node;
|
||||||
depend_node = common::AnfAlgo::GetInputNode(tensormove_cnode, 0);
|
depend_node = common::AnfAlgo::GetInputNode(tensormove_cnode, 0);
|
||||||
}
|
}
|
||||||
if (IsPrimitiveCNode(depend_node, prim::kPrimDepend)) {
|
if (IsPrimitiveCNode(depend_node, prim::kPrimDepend)) {
|
||||||
auto depend_cnode = depend_node->cast<CNodePtr>();
|
auto depend_cnode = depend_node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||||
AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
|
AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
|
||||||
for (auto &node_pair : allgather_node_set) {
|
for (auto &node_pair : allgather_node_set) {
|
||||||
auto allgather_next_node = node_pair.first;
|
auto allgather_next_node = node_pair.first;
|
||||||
|
@ -128,8 +137,10 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
|
||||||
} else if (IsPrimitiveCNode(depend_node, prim::kPrimCast) &&
|
} else if (IsPrimitiveCNode(depend_node, prim::kPrimCast) &&
|
||||||
IsPrimitiveCNode(common::AnfAlgo::GetInputNode(depend_node->cast<CNodePtr>(), 0), prim::kPrimDepend)) {
|
IsPrimitiveCNode(common::AnfAlgo::GetInputNode(depend_node->cast<CNodePtr>(), 0), prim::kPrimDepend)) {
|
||||||
auto cast_cnode = depend_node->cast<CNodePtr>();
|
auto cast_cnode = depend_node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cast_cnode);
|
||||||
auto cast_depend_node = common::AnfAlgo::GetInputNode(cast_cnode, 0);
|
auto cast_depend_node = common::AnfAlgo::GetInputNode(cast_cnode, 0);
|
||||||
auto cast_depend_cnode = cast_depend_node->cast<CNodePtr>();
|
auto cast_depend_cnode = cast_depend_node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cast_depend_cnode);
|
||||||
AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
|
AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
|
||||||
for (auto &node_pair : allgather_node_set) {
|
for (auto &node_pair : allgather_node_set) {
|
||||||
auto allgather_next_node = node_pair.first;
|
auto allgather_next_node = node_pair.first;
|
||||||
|
|
|
@ -69,6 +69,7 @@ AnfNodePtr ClipByNormFission::CreateCNodeBase(const FuncGraphPtr &func_graph, co
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(std::make_shared<Primitive>(op_name))};
|
std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(std::make_shared<Primitive>(op_name))};
|
||||||
for (const auto &inp : inps) {
|
for (const auto &inp : inps) {
|
||||||
|
MS_EXCEPTION_IF_NULL(inp);
|
||||||
(void)new_node_inputs.emplace_back(inp);
|
(void)new_node_inputs.emplace_back(inp);
|
||||||
}
|
}
|
||||||
auto new_node = NewCNode(new_node_inputs, func_graph);
|
auto new_node = NewCNode(new_node_inputs, func_graph);
|
||||||
|
|
Loading…
Reference in New Issue