From 54dc30f571d4de2b4b837d8390ca8a61508b31af Mon Sep 17 00:00:00 2001
From: Zhang Qinghua
Date: Sun, 24 Apr 2022 15:19:13 +0800
Subject: [PATCH] Resolve the thor layer's missing parameter during grad operation.

- The Parameter is used only in the bprop procedure inserted by
  InsertGradientOf, not in the forward construct().
---
 mindspore/core/ir/func_graph_base.cc          |  1 +
 mindspore/core/ir/manager.cc                  | 18 ++++++++++++++++----
 .../python/mindspore/nn/layer/thor_layer.py   | 12 +++++++++---
 3 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/mindspore/core/ir/func_graph_base.cc b/mindspore/core/ir/func_graph_base.cc
index 0fc3c265673..d7b41164459 100644
--- a/mindspore/core/ir/func_graph_base.cc
+++ b/mindspore/core/ir/func_graph_base.cc
@@ -31,6 +31,7 @@ FuncGraphLoopBreaker &FuncGraphLoopBreaker::Inst() {
   static FuncGraphLoopBreaker mgr;
   return mgr;
 }
+
 void FuncGraphLoopBreaker::BreakLoop() {
   MS_LOG(INFO) << "Size of not recycled graph before break loop is:" << func_set_.size();
   std::list<FuncGraphBasePtr> func_list;
diff --git a/mindspore/core/ir/manager.cc b/mindspore/core/ir/manager.cc
index 821140b84bb..b6895e4067d 100644
--- a/mindspore/core/ir/manager.cc
+++ b/mindspore/core/ir/manager.cc
@@ -24,6 +24,7 @@
 #include "ir/func_graph.h"
 #include "utils/convert_utils_base.h"
 #include "utils/counter.h"
+#include "utils/trace_base.h"
 #include "mindspore/core/ops/core_ops.h"
 
 namespace mindspore {
@@ -233,7 +234,9 @@ void FuncGraphManager::Init() {
 }
 
 FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const {
-  MS_EXCEPTION_IF_NULL(fg);
+  if (fg == nullptr) {
+    MS_LOG(EXCEPTION) << "The parameter 'fg' should not be null.";
+  }
   MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
   func_graph_parents_total_->Recompute(fg);
   MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString();
@@ -883,7 +886,14 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(
   // Append all the fvs in fg.
   auto &fvs = fg->free_variables();
   for (auto fv : fvs) {
-    parents->add(fv.first->func_graph());
+    auto fv_node = fv.first;
+    auto fv_func_graph = fv_node->func_graph();
+    if (fv_func_graph == nullptr) {
+      MS_LOG(INFO) << "Met a free variable '" << fv_node->DebugString() << "' whose func graph is null while"
+                   << " seeking parents of " << fg->ToString() << "\nFV: " << trace::GetDebugInfo(fv_node->debug_info());
+      continue;
+    }
+    parents->add(fv_func_graph);
   }
 
   // Search the fv in fg's child func graph.
@@ -906,7 +916,8 @@ void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
   MS_EXCEPTION_IF_NULL(fg);
   mindspore::HashMap<FuncGraphPtr, FuncGraphSetPtr> seen_fgs;
   fg->seen_ = 1;
-  func_graph_parents_total_analysis_[fg].update(SeekParents(fg, &seen_fgs));
+  auto parents = SeekParents(fg, &seen_fgs);
+  func_graph_parents_total_analysis_[fg].update(parents);
   fg->seen_ = 0;
 }
 
@@ -920,7 +931,6 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
   MS_EXCEPTION_IF_NULL(fg);
   this->parent_analysis_[fg] = nullptr;
   // Note: must be a copy other than reference as it is modified thereafter.
   auto deps = this->manager_->func_graph_parents_total(fg);
-
   if (deps.empty()) {
     this->parent_analysis_[fg] = nullptr;
     return;
diff --git a/mindspore/python/mindspore/nn/layer/thor_layer.py b/mindspore/python/mindspore/nn/layer/thor_layer.py
index 6f7100154d9..98f00597716 100644
--- a/mindspore/python/mindspore/nn/layer/thor_layer.py
+++ b/mindspore/python/mindspore/nn/layer/thor_layer.py
@@ -187,7 +187,9 @@ class DenseThor(Cell):
         x = self.bias_add(x, self.bias)
         if self.activation_flag:
             x = self.activation(x)
-        return x
+        # Use Depend to attach 'self.matrix_g' to the primal graph as a weight parameter,
+        # since it is used only in the 'save_gradient' bprop procedure.
+        return F.depend(x, self.matrix_g)
 
     def extend_repr(self):
         s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
@@ -644,7 +646,9 @@ class EmbeddingThor(Cell):
         output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
 
         output = self.reshape(output_for_reshape, out_shape)
-        return output
+        # Use Depend to attach 'self.matrix_g' to the primal graph as a weight parameter,
+        # since it is used only in the 'save_gradient' bprop procedure.
+        return F.depend(output, self.matrix_g)
 
     def extend_repr(self):
         s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
@@ -949,4 +953,6 @@ class EmbeddingLookupThor(Cell):
             axis = _make_axis_range(F.rank(indices), F.rank(out))
             clip_by_norm = ClipByNorm(axis)
             out = clip_by_norm(out, self.max_norm)
-        return out
+        # Use Depend to attach 'self.matrix_g' to the primal graph as a weight parameter,
+        # since it is used only in the 'save_gradient' bprop procedure.
+        return F.depend(out, self.matrix_g)
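
For reference, below is a minimal, self-contained sketch of the pattern this patch relies on. It is not part of the patch: the cell 'HookedScale' and the parameter 'grad_sq' are illustrative stand-ins for the thor layers and 'self.matrix_g'. A Parameter that is written only inside a bprop hook registered with InsertGradientOf never appears in the forward graph, so F.depend is used in construct() to attach it to the primal graph:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Parameter, ParameterTuple, Tensor, context
from mindspore.ops import functional as F

context.set_context(mode=context.GRAPH_MODE)

class HookedScale(nn.Cell):
    """Scales the input; records gradient statistics in a bprop hook."""
    def __init__(self):
        super(HookedScale, self).__init__()
        self.weight = Parameter(Tensor(np.ones((2, 2)), ms.float32), name='weight')
        # Written only inside save_gradient(), never read in the forward pass.
        self.grad_sq = Parameter(Tensor(np.zeros((2, 2)), ms.float32),
                                 name='grad_sq', requires_grad=False)
        self.get_grad = ops.InsertGradientOf(self.save_gradient)

    def save_gradient(self, dout):
        # Runs during backprop: stash the squared gradient, pass dout through.
        self.grad_sq = dout * dout
        return dout

    def construct(self, x):
        y = self.get_grad(x * self.weight)
        # Without this Depend, 'self.grad_sq' appears in no forward-graph node,
        # so the grad operation may not see it as a parameter of the primal graph.
        return F.depend(y, self.grad_sq)

net = HookedScale()
grad_fn = ops.GradOperation(get_by_list=True)(net, ParameterTuple(net.trainable_params()))
print(grad_fn(Tensor(np.ones((2, 2)), ms.float32)))

The Depend makes the otherwise-unused parameter a free variable of the primal graph during the grad transform, which is presumably why SeekParents in manager.cc is hardened in the same commit to skip free variables whose func graph is null instead of dereferencing them.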