codeclean

This commit is contained in:
ttudu 2022-11-01 18:51:19 +08:00
parent 7513363f30
commit d8addd93b7
6 changed files with 27 additions and 0 deletions

View File

@ -39,6 +39,7 @@ bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) const {
} }
int64_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) const { int64_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
if (HasFusionIdAttr(node)) { if (HasFusionIdAttr(node)) {
return common::AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFusionId); return common::AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFusionId);
} }
@ -46,6 +47,7 @@ int64_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) const {
} }
void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int64_t id) const { void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int64_t id) const {
MS_EXCEPTION_IF_NULL(node);
ValuePtr fusion_id_v = MakeValue(id); ValuePtr fusion_id_v = MakeValue(id);
common::AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node); common::AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node);
} }

View File

@ -43,10 +43,13 @@ constexpr size_t kType64Len = 8;
constexpr auto kNopNodeRealInputIndex = 1; constexpr auto kNopNodeRealInputIndex = 1;
void UpdateDumpFlagAndDebugInfo(const CNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) { void UpdateDumpFlagAndDebugInfo(const CNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
MS_EXCEPTION_IF_NULL(node);
std::vector<AnfNodePtr> orig_real_cnodes; std::vector<AnfNodePtr> orig_real_cnodes;
for (auto &orig_node : orig_nodes) { for (auto &orig_node : orig_nodes) {
MS_EXCEPTION_IF_NULL(orig_node);
if (AnfUtils::IsRealCNodeKernel(orig_node)) { if (AnfUtils::IsRealCNodeKernel(orig_node)) {
auto orig_cnode = orig_node->cast<CNodePtr>(); auto orig_cnode = orig_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(orig_cnode);
if (common::AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) { if (common::AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
common::AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node); common::AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
} }
@ -160,6 +163,7 @@ bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum); auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
MS_EXCEPTION_IF_NULL(transop_cnode); MS_EXCEPTION_IF_NULL(transop_cnode);
@ -620,6 +624,8 @@ ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vec
} }
CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, const CNodePtr &node, const bool is_input) { CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, const CNodePtr &node, const bool is_input) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
std::vector<AnfNodePtr> new_cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name()))}; std::vector<AnfNodePtr> new_cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name()))};
BaseShapePtr shape; BaseShapePtr shape;
if (is_input) { if (is_input) {
@ -640,6 +646,8 @@ CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, cons
AnfNodePtr CreateNodeBase(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &new_node_inputs, AnfNodePtr CreateNodeBase(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &new_node_inputs,
const AnfNodePtr &node) { const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto new_node = graph->NewCNode(new_node_inputs); auto new_node = graph->NewCNode(new_node_inputs);
MS_EXCEPTION_IF_NULL(new_node); MS_EXCEPTION_IF_NULL(new_node);
@ -840,6 +848,7 @@ AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive, AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
const AbstractBasePtrList &input_abstract) { const AbstractBasePtrList &input_abstract) {
MS_EXCEPTION_IF_NULL(primitive);
auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes); auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
if (dynamic_inputs_list == nullptr) { if (dynamic_inputs_list == nullptr) {
return input_abstract; return input_abstract;

View File

@ -54,6 +54,8 @@ void AddOutputAndCallerToMap(const CNodePtr &cnode, mindspore::HashMap<AnfNodePt
void SkipSameOp(const AnfNodePtr &old_node, const AnfNodePtr &new_node, mindspore::HashSet<AnfNodePtr> *seen_node) { void SkipSameOp(const AnfNodePtr &old_node, const AnfNodePtr &new_node, mindspore::HashSet<AnfNodePtr> *seen_node) {
MS_EXCEPTION_IF_NULL(seen_node); MS_EXCEPTION_IF_NULL(seen_node);
MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(new_node);
if (old_node->isa<CNode>() && new_node->isa<CNode>() && if (old_node->isa<CNode>() && new_node->isa<CNode>() &&
(common::AnfAlgo::GetCNodeName(old_node) == common::AnfAlgo::GetCNodeName(new_node))) { (common::AnfAlgo::GetCNodeName(old_node) == common::AnfAlgo::GetCNodeName(new_node))) {
(void)seen_node->insert(new_node); (void)seen_node->insert(new_node);

View File

@ -33,6 +33,8 @@ mindspore::HashMap<std::string, mindspore::HashSet<std::string>> MarkOp{
{"LSTM", {"LSTMGradWeight", "LSTMGrad", "LSTMGradData"}}}; {"LSTM", {"LSTMGradWeight", "LSTMGrad", "LSTMGradData"}}};
bool CheckOP(const FuncGraphManagerPtr &manager, const AnfNodePtr &cnode, const mindspore::HashSet<std::string> &set) { bool CheckOP(const FuncGraphManagerPtr &manager, const AnfNodePtr &cnode, const mindspore::HashSet<std::string> &set) {
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(cnode);
for (const auto &node_index : manager->node_users()[cnode]) { for (const auto &node_index : manager->node_users()[cnode]) {
auto output = node_index.first; auto output = node_index.first;
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);

View File

@ -81,12 +81,14 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusion
if (recompute_min_fusion_id <= unrecompute_max_fusion_id) { if (recompute_min_fusion_id <= unrecompute_max_fusion_id) {
MS_LOG(WARNING) << "Increase the duplicated allgather fusion id"; MS_LOG(WARNING) << "Increase the duplicated allgather fusion id";
for (auto &adjust_node : parallel_optimizer_recompute_first_fusion_allgathers) { for (auto &adjust_node : parallel_optimizer_recompute_first_fusion_allgathers) {
MS_EXCEPTION_IF_NULL(adjust_node);
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion); int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
int64_t destination_fusion_id = int64_t destination_fusion_id =
(kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id; (kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id;
common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node); common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
} }
for (auto &adjust_node : parallel_optimizer_recompute_allgathers) { for (auto &adjust_node : parallel_optimizer_recompute_allgathers) {
MS_EXCEPTION_IF_NULL(adjust_node);
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion); int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
int64_t destination_fusion_id = int64_t destination_fusion_id =
(kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id; (kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id;
@ -97,19 +99,26 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusion
bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend( bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &parallel_optimizer_recompute_allgathers) const { const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &parallel_optimizer_recompute_allgathers) const {
MS_EXCEPTION_IF_NULL(graph);
FuncGraphManagerPtr manager = graph->manager(); FuncGraphManagerPtr manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
bool changed = false; bool changed = false;
for (auto &node : parallel_optimizer_recompute_allgathers) { for (auto &node : parallel_optimizer_recompute_allgathers) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto depend_node = common::AnfAlgo::GetInputNode(cnode, 0); auto depend_node = common::AnfAlgo::GetInputNode(cnode, 0);
MS_EXCEPTION_IF_NULL(depend_node);
auto set_edge_node = node; auto set_edge_node = node;
if (IsPrimitiveCNode(depend_node, prim::kPrimTensorMove)) { if (IsPrimitiveCNode(depend_node, prim::kPrimTensorMove)) {
auto tensormove_cnode = depend_node->cast<CNodePtr>(); auto tensormove_cnode = depend_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tensormove_cnode);
set_edge_node = depend_node; set_edge_node = depend_node;
depend_node = common::AnfAlgo::GetInputNode(tensormove_cnode, 0); depend_node = common::AnfAlgo::GetInputNode(tensormove_cnode, 0);
} }
if (IsPrimitiveCNode(depend_node, prim::kPrimDepend)) { if (IsPrimitiveCNode(depend_node, prim::kPrimDepend)) {
auto depend_cnode = depend_node->cast<CNodePtr>(); auto depend_cnode = depend_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_cnode);
AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode]; AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
for (auto &node_pair : allgather_node_set) { for (auto &node_pair : allgather_node_set) {
auto allgather_next_node = node_pair.first; auto allgather_next_node = node_pair.first;
@ -128,8 +137,10 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
} else if (IsPrimitiveCNode(depend_node, prim::kPrimCast) && } else if (IsPrimitiveCNode(depend_node, prim::kPrimCast) &&
IsPrimitiveCNode(common::AnfAlgo::GetInputNode(depend_node->cast<CNodePtr>(), 0), prim::kPrimDepend)) { IsPrimitiveCNode(common::AnfAlgo::GetInputNode(depend_node->cast<CNodePtr>(), 0), prim::kPrimDepend)) {
auto cast_cnode = depend_node->cast<CNodePtr>(); auto cast_cnode = depend_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cast_cnode);
auto cast_depend_node = common::AnfAlgo::GetInputNode(cast_cnode, 0); auto cast_depend_node = common::AnfAlgo::GetInputNode(cast_cnode, 0);
auto cast_depend_cnode = cast_depend_node->cast<CNodePtr>(); auto cast_depend_cnode = cast_depend_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cast_depend_cnode);
AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode]; AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
for (auto &node_pair : allgather_node_set) { for (auto &node_pair : allgather_node_set) {
auto allgather_next_node = node_pair.first; auto allgather_next_node = node_pair.first;

View File

@ -69,6 +69,7 @@ AnfNodePtr ClipByNormFission::CreateCNodeBase(const FuncGraphPtr &func_graph, co
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(std::make_shared<Primitive>(op_name))}; std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(std::make_shared<Primitive>(op_name))};
for (const auto &inp : inps) { for (const auto &inp : inps) {
MS_EXCEPTION_IF_NULL(inp);
(void)new_node_inputs.emplace_back(inp); (void)new_node_inputs.emplace_back(inp);
} }
auto new_node = NewCNode(new_node_inputs, func_graph); auto new_node = NewCNode(new_node_inputs, func_graph);