codeclean
This commit is contained in:
parent
7513363f30
commit
d8addd93b7
|
@ -39,6 +39,7 @@ bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) const {
|
|||
}
|
||||
|
||||
int64_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (HasFusionIdAttr(node)) {
|
||||
return common::AnfAlgo::GetNodeAttr<int64_t>(node, kAttrFusionId);
|
||||
}
|
||||
|
@ -46,6 +47,7 @@ int64_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) const {
|
|||
}
|
||||
|
||||
void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int64_t id) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
ValuePtr fusion_id_v = MakeValue(id);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node);
|
||||
}
|
||||
|
|
|
@ -43,10 +43,13 @@ constexpr size_t kType64Len = 8;
|
|||
constexpr auto kNopNodeRealInputIndex = 1;
|
||||
|
||||
void UpdateDumpFlagAndDebugInfo(const CNodePtr &node, const std::vector<AnfNodePtr> &orig_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::vector<AnfNodePtr> orig_real_cnodes;
|
||||
for (auto &orig_node : orig_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(orig_node);
|
||||
if (AnfUtils::IsRealCNodeKernel(orig_node)) {
|
||||
auto orig_cnode = orig_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(orig_cnode);
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrDump, orig_cnode)) {
|
||||
common::AnfAlgo::CopyNodeAttr(kAttrDump, orig_cnode, node);
|
||||
}
|
||||
|
@ -160,6 +163,7 @@ bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y
|
|||
|
||||
const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(transop_cnode);
|
||||
|
@ -620,6 +624,8 @@ ValueNodePtr CreateShapeValueNode(const FuncGraphPtr &func_graph, const std::vec
|
|||
}
|
||||
|
||||
CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, const CNodePtr &node, const bool is_input) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::vector<AnfNodePtr> new_cast_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name()))};
|
||||
BaseShapePtr shape;
|
||||
if (is_input) {
|
||||
|
@ -640,6 +646,8 @@ CNodePtr AddCastNode(const FuncGraphPtr &func_graph, const TypeId dst_type, cons
|
|||
|
||||
AnfNodePtr CreateNodeBase(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &new_node_inputs,
|
||||
const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto new_node = graph->NewCNode(new_node_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
|
||||
|
@ -840,6 +848,7 @@ AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
|
|||
|
||||
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &input_abstract) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
|
||||
if (dynamic_inputs_list == nullptr) {
|
||||
return input_abstract;
|
||||
|
|
|
@ -54,6 +54,8 @@ void AddOutputAndCallerToMap(const CNodePtr &cnode, mindspore::HashMap<AnfNodePt
|
|||
|
||||
void SkipSameOp(const AnfNodePtr &old_node, const AnfNodePtr &new_node, mindspore::HashSet<AnfNodePtr> *seen_node) {
|
||||
MS_EXCEPTION_IF_NULL(seen_node);
|
||||
MS_EXCEPTION_IF_NULL(old_node);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
if (old_node->isa<CNode>() && new_node->isa<CNode>() &&
|
||||
(common::AnfAlgo::GetCNodeName(old_node) == common::AnfAlgo::GetCNodeName(new_node))) {
|
||||
(void)seen_node->insert(new_node);
|
||||
|
|
|
@ -33,6 +33,8 @@ mindspore::HashMap<std::string, mindspore::HashSet<std::string>> MarkOp{
|
|||
{"LSTM", {"LSTMGradWeight", "LSTMGrad", "LSTMGradData"}}};
|
||||
|
||||
bool CheckOP(const FuncGraphManagerPtr &manager, const AnfNodePtr &cnode, const mindspore::HashSet<std::string> &set) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (const auto &node_index : manager->node_users()[cnode]) {
|
||||
auto output = node_index.first;
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
|
|
|
@ -81,12 +81,14 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusion
|
|||
if (recompute_min_fusion_id <= unrecompute_max_fusion_id) {
|
||||
MS_LOG(WARNING) << "Increase the duplicated allgather fusion id";
|
||||
for (auto &adjust_node : parallel_optimizer_recompute_first_fusion_allgathers) {
|
||||
MS_EXCEPTION_IF_NULL(adjust_node);
|
||||
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
|
||||
int64_t destination_fusion_id =
|
||||
(kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id;
|
||||
common::AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(destination_fusion_id), adjust_node);
|
||||
}
|
||||
for (auto &adjust_node : parallel_optimizer_recompute_allgathers) {
|
||||
MS_EXCEPTION_IF_NULL(adjust_node);
|
||||
int64_t current_fusion_id = common::AnfAlgo::GetNodeAttr<int64_t>(adjust_node, kAttrFusion);
|
||||
int64_t destination_fusion_id =
|
||||
(kFusionGap + current_fusion_id + unrecompute_max_fusion_id) - recompute_min_fusion_id;
|
||||
|
@ -97,19 +99,26 @@ void AdjustDependForParallelOptimizerRecomputeAllGather::IncreaseAllgatherFusion
|
|||
|
||||
bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
|
||||
const FuncGraphPtr &graph, const std::vector<AnfNodePtr> ¶llel_optimizer_recompute_allgathers) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
FuncGraphManagerPtr manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
bool changed = false;
|
||||
for (auto &node : parallel_optimizer_recompute_allgathers) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto depend_node = common::AnfAlgo::GetInputNode(cnode, 0);
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
auto set_edge_node = node;
|
||||
if (IsPrimitiveCNode(depend_node, prim::kPrimTensorMove)) {
|
||||
auto tensormove_cnode = depend_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensormove_cnode);
|
||||
set_edge_node = depend_node;
|
||||
depend_node = common::AnfAlgo::GetInputNode(tensormove_cnode, 0);
|
||||
}
|
||||
if (IsPrimitiveCNode(depend_node, prim::kPrimDepend)) {
|
||||
auto depend_cnode = depend_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||
AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
|
||||
for (auto &node_pair : allgather_node_set) {
|
||||
auto allgather_next_node = node_pair.first;
|
||||
|
@ -128,8 +137,10 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::AdjustAllgatherDepend(
|
|||
} else if (IsPrimitiveCNode(depend_node, prim::kPrimCast) &&
|
||||
IsPrimitiveCNode(common::AnfAlgo::GetInputNode(depend_node->cast<CNodePtr>(), 0), prim::kPrimDepend)) {
|
||||
auto cast_cnode = depend_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cast_cnode);
|
||||
auto cast_depend_node = common::AnfAlgo::GetInputNode(cast_cnode, 0);
|
||||
auto cast_depend_cnode = cast_depend_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cast_depend_cnode);
|
||||
AnfNodeIndexSet allgather_node_set = manager->node_users()[cnode];
|
||||
for (auto &node_pair : allgather_node_set) {
|
||||
auto allgather_next_node = node_pair.first;
|
||||
|
|
|
@ -69,6 +69,7 @@ AnfNodePtr ClipByNormFission::CreateCNodeBase(const FuncGraphPtr &func_graph, co
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(std::make_shared<Primitive>(op_name))};
|
||||
for (const auto &inp : inps) {
|
||||
MS_EXCEPTION_IF_NULL(inp);
|
||||
(void)new_node_inputs.emplace_back(inp);
|
||||
}
|
||||
auto new_node = NewCNode(new_node_inputs, func_graph);
|
||||
|
|
Loading…
Reference in New Issue