codeclean

This commit is contained in:
ttudu 2022-11-01 18:51:19 +08:00
parent 7513363f30
commit d8addd93b7
6 changed files with 27 additions and 0 deletions

View File

@ -39,6 +39,7 @@ bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) const {
}
int64_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
if (HasFusionIdAttr(node)) {
return common::AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFusionId);
}
@ -46,6 +47,7 @@ int64_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) const {
}
void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int64_t id) const {
MS_EXCEPTION_IF_NULL(node);
ValuePtr fusion_id_v = MakeValue(id);
common::AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node);
}

View File

@ -43,10 +43,13 @@ constexpr size_t kType64Len = 8;
constexpr auto kNopNodeRealInputIndex = 1;
void UpdateDumpFlagAndDebugInfo(const CNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
MS_EXCEPTION_IF_NULL(node);
std::vector<AnfNodePtr> orig_real_cnodes;
for (auto &orig_node : orig_nodes) {
MS_EXCEPTION_IF_NULL(orig_node);
if (AnfUtils::IsRealCNodeKernel(orig_node)) {
auto orig_cnode = orig_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(orig_cnode);
if (common::AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
common::AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
}
@ -160,6 +163,7 @@ bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
MS_EXCEPTION_IF_NULL(transop_cnode);
@ -620,6 +624,8 @@ ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vec
}
CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, const CNodePtr &node, const bool is_input) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
std::vector<AnfNodePtr> new_cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name()))};
BaseShapePtr shape;
if (is_input) {
@ -640,6 +646,8 @@ CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, cons
AnfNodePtr CreateNodeBase(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &new_node_inputs,
const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto new_node = graph->NewCNode(new_node_inputs);
MS_EXCEPTION_IF_NULL(new_node);
@ -840,6 +848,7 @@ AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
const AbstractBasePtrList &input_abstract) {
MS_EXCEPTION_IF_NULL(primitive);
auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
if (dynamic_inputs_list == nullptr) {
return input_abstract;

View File

@ -54,6 +54,8 @@ void AddOutputAndCallerToMap(const CNodePtr &cnode, mindspore::HashMap<AnfNodePt
void SkipSameOp(const AnfNodePtr &old_node, const AnfNodePtr &new_node, mindspore::HashSet<AnfNodePtr> *seen_node) {
MS_EXCEPTION_IF_NULL(seen_node);
MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(new_node);
if (old_node->isa<CNode>() && new_node->isa<CNode>() &&
(common::AnfAlgo::GetCNodeName(old_node) == common::AnfAlgo::GetCNodeName(new_node))) {
(void)seen_node->insert(new_node);

View File

@ -33,6 +33,8 @@ mindspore::HashMap<std::string, mindspore::HashSet<std::string>> MarkOp{
{"LSTM", {"LSTMGradWeight", "LSTMGrad", "LSTMGradData"}}};
bool CheckOP(const FuncGraphManagerPtr &manager, const AnfNodePtr &cnode, const mindspore::HashSet<std::string> &set) {
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(cnode);
for (const auto &node_index : manager->node_users()[cnode]) {
auto output = node_index.first;
MS_EXCEPTION_IF_NULL(output);

View File

@ -81,12 +81,14 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusion
if (recompute_min_fusion_id <= unrecompute_max_fusion_id) {
MS_LOG(WARNING) << "Increase the duplicated allgather fusion id";
for (auto &adjust_node : parallel_optimizer_recompute_first_fusion_allgathers) {
MS_EXCEPTION_IF_NULL(adjust_node);
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
int64_t destination_fusion_id =
(kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id;
common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
}
for (auto &adjust_node : parallel_optimizer_recompute_allgathers) {
MS_EXCEPTION_IF_NULL(adjust_node);
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
int64_t destination_fusion_id =
(kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id;
@ -97,19 +99,26 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusion
bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &parallel_optimizer_recompute_allgathers) const {
MS_EXCEPTION_IF_NULL(graph);
FuncGraphManagerPtr manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
bool changed = false;
for (auto &node : parallel_optimizer_recompute_allgathers) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto depend_node = common::AnfAlgo::GetInputNode(cnode, 0);
MS_EXCEPTION_IF_NULL(depend_node);
auto set_edge_node = node;
if (IsPrimitiveCNode(depend_node, prim::kPrimTensorMove)) {
auto tensormove_cnode = depend_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tensormove_cnode);
set_edge_node = depend_node;
depend_node = common::AnfAlgo::GetInputNode(tensormove_cnode, 0);
}
if (IsPrimitiveCNode(depend_node, prim::kPrimDepend)) {
auto depend_cnode = depend_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_cnode);
AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
for (auto &node_pair : allgather_node_set) {
auto allgather_next_node = node_pair.first;
@ -128,8 +137,10 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
} else if (IsPrimitiveCNode(depend_node, prim::kPrimCast) &&
IsPrimitiveCNode(common::AnfAlgo::GetInputNode(depend_node->cast<CNodePtr>(), 0), prim::kPrimDepend)) {
auto cast_cnode = depend_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cast_cnode);
auto cast_depend_node = common::AnfAlgo::GetInputNode(cast_cnode, 0);
auto cast_depend_cnode = cast_depend_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cast_depend_cnode);
AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
for (auto &node_pair : allgather_node_set) {
auto allgather_next_node = node_pair.first;

View File

@ -69,6 +69,7 @@ AnfNodePtr ClipByNormFission::CreateCNodeBase(const FuncGraphPtr &func_graph, co
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(std::make_shared<Primitive>(op_name))};
for (const auto &inp : inps) {
MS_EXCEPTION_IF_NULL(inp);
(void)new_node_inputs.emplace_back(inp);
}
auto new_node = NewCNode(new_node_inputs, func_graph);