!33454 Resolve the THOR layer's missing parameter during the grad operation.

Merge pull request !33454 from 张清华/opt
i-robot 2022-04-24 11:08:01 +00:00 committed by Gitee
commit b350e16c13
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 24 additions and 7 deletions

View File

@@ -31,6 +31,7 @@ FuncGraphLoopBreaker &FuncGraphLoopBreaker::Inst() {
static FuncGraphLoopBreaker mgr;
return mgr;
}
void FuncGraphLoopBreaker::BreakLoop() {
MS_LOG(INFO) << "Size of not recycled graph before break loop is:" << func_set_.size();
std::list<FuncGraphBasePtr> func_list;

View File

@@ -24,6 +24,7 @@
#include "ir/func_graph.h"
#include "utils/convert_utils_base.h"
#include "utils/counter.h"
#include "utils/trace_base.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
@@ -233,7 +234,9 @@ void FuncGraphManager::Init() {
}
FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const {
MS_EXCEPTION_IF_NULL(fg);
if (fg == nullptr) {
MS_LOG(EXCEPTION) << "The parameter 'fg' should not be null.";
}
MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
func_graph_parents_total_->Recompute(fg);
MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString();
@@ -883,7 +886,14 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(
// Append all the fvs in fg.
auto &fvs = fg->free_variables();
for (auto fv : fvs) {
parents->add(fv.first->func_graph());
auto fv_node = fv.first;
auto fv_func_graph = fv_node->func_graph();
if (fv_func_graph == nullptr) {
MS_LOG(INFO) << "Meet a FV '" << fv_node->DebugString() << "' whose func graph is null, during seeking for "
<< fg->ToString() << "\nFV: " << trace::GetDebugInfo(fv_node->debug_info());
continue;
}
parents->add(fv_func_graph);
}
// Search the fv in fg's child func graph.
@@ -906,7 +916,8 @@ void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(fg);
mindspore::HashMap<FuncGraphPtr, FuncGraphSetPtr> seen_fgs;
fg->seen_ = 1;
func_graph_parents_total_analysis_[fg].update(SeekParents(fg, &seen_fgs));
auto parents = SeekParents(fg, &seen_fgs);
func_graph_parents_total_analysis_[fg].update(parents);
fg->seen_ = 0;
}
@@ -920,7 +931,6 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
this->parent_analysis_[fg] = nullptr;
// Note: must be a copy other than reference as it is modified thereafter.
auto deps = this->manager_->func_graph_parents_total(fg);
if (deps.empty()) {
this->parent_analysis_[fg] = nullptr;
return;

View File

@@ -187,7 +187,9 @@ class DenseThor(Cell):
x = self.bias_add(x, self.bias)
if self.activation_flag:
x = self.activation(x)
return x
# We use Depend to make 'self.matrix_g' a weight parameter of the primal graph,
# because it is used in the 'save_gradient' gradient procedure.
return F.depend(x, self.matrix_g)
def extend_repr(self):
s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
@@ -644,7 +646,9 @@ class EmbeddingThor(Cell):
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
output = self.reshape(output_for_reshape, out_shape)
return output
# We use Depend to make 'self.matrix_g' a weight parameter of the primal graph,
# because it is used in the 'save_gradient' gradient procedure.
return F.depend(output, self.matrix_g)
def extend_repr(self):
s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
@@ -949,4 +953,6 @@ class EmbeddingLookupThor(Cell):
axis = _make_axis_range(F.rank(indices), F.rank(out))
clip_by_norm = ClipByNorm(axis)
out = clip_by_norm(out, self.max_norm)
return out
# We use Depend to make 'self.matrix_g' a weight parameter of the primal graph,
# because it is used in the 'save_gradient' gradient procedure.
return F.depend(out, self.matrix_g)
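
The comments added to DenseThor, EmbeddingThor and EmbeddingLookupThor above describe the core of the fix: routing the output through Depend makes 'self.matrix_g' a weight parameter of the primal graph, so the 'save_gradient' procedure can still reach it when the grad operation runs. Below is a minimal sketch of that pattern for illustration only; it is not code from this pull request, and the names 'ToyThorLayer', 'GradByList' and 'aux_stat' are hypothetical stand-ins for the real THOR layers and 'matrix_g'. It assumes a working MindSpore installation.

import numpy as np

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore.ops import functional as F


class ToyThorLayer(nn.Cell):
    """Toy layer: 'aux_stat' stands in for 'matrix_g' and is kept alive via Depend."""

    def __init__(self):
        super(ToyThorLayer, self).__init__()
        self.weight = Parameter(Tensor(np.ones((2, 2), np.float32)), name='weight')
        # Auxiliary statistic parameter: it does not contribute to the forward value itself.
        self.aux_stat = Parameter(Tensor(np.zeros((2, 2), np.float32)), name='aux_stat')
        self.matmul = ops.MatMul()

    def construct(self, x):
        out = self.matmul(x, self.weight)
        # Depend returns 'out' unchanged but ties 'aux_stat' into the primal graph,
        # so a gradient procedure that walks the graph's weights can still reach it.
        return F.depend(out, self.aux_stat)


class GradByList(nn.Cell):
    """Standard GradOperation wrapper that differentiates w.r.t. a parameter list."""

    def __init__(self, net):
        super(GradByList, self).__init__()
        self.net = net
        self.params = ParameterTuple(net.trainable_params())
        self.grad_op = ops.GradOperation(get_by_list=True)

    def construct(self, x):
        return self.grad_op(self.net, self.params)(x)


net = ToyThorLayer()
grads = GradByList(net)(Tensor(np.ones((2, 2), np.float32)))
print(len(grads))  # one gradient entry per trainable parameter

If the F.depend call is removed from construct, 'aux_stat' is not referenced by the primal graph at all, which is the kind of missing-parameter situation the THOR layers hit during the grad operation.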