!33454 Resolve the thor layer's missing parameter during the grad operation.
Merge pull request !33454 from 张清华/opt
commit b350e16c13
@@ -31,6 +31,7 @@ FuncGraphLoopBreaker &FuncGraphLoopBreaker::Inst() {
   static FuncGraphLoopBreaker mgr;
   return mgr;
 }
 
 void FuncGraphLoopBreaker::BreakLoop() {
+  MS_LOG(INFO) << "Size of not recycled graph before break loop is:" << func_set_.size();
   std::list<FuncGraphBasePtr> func_list;
@@ -24,6 +24,7 @@
 #include "ir/func_graph.h"
 #include "utils/convert_utils_base.h"
 #include "utils/counter.h"
+#include "utils/trace_base.h"
 #include "mindspore/core/ops/core_ops.h"
 
 namespace mindspore {
@@ -233,7 +234,9 @@ void FuncGraphManager::Init() {
 }
 
 FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const {
-  MS_EXCEPTION_IF_NULL(fg);
+  if (fg == nullptr) {
+    MS_LOG(EXCEPTION) << "The parameter 'fg' should not be null.";
+  }
   MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString();
   func_graph_parents_total_->Recompute(fg);
   MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString();
@@ -883,7 +886,14 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(
   // Append all the fvs in fg.
   auto &fvs = fg->free_variables();
   for (auto fv : fvs) {
-    parents->add(fv.first->func_graph());
+    auto fv_node = fv.first;
+    auto fv_func_graph = fv_node->func_graph();
+    if (fv_func_graph == nullptr) {
+      MS_LOG(INFO) << "Meet a FV '" << fv_node->DebugString() << "' whose func graph is null, during seeking for "
+                   << fg->ToString() << "\nFV: " << trace::GetDebugInfo(fv_node->debug_info());
+      continue;
+    }
+    parents->add(fv_func_graph);
   }
 
   // Search the fv in fg's child func graph.
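The hunk above is the guard on the C++ side: a free variable whose owning func graph has already been recycled is now logged and skipped, instead of a null graph being inserted into the parents set. A hypothetical Python analogue of that log-and-skip guard (the names seek_parent_scopes and node.scope are illustrative only, not MindSpore API):

    # Hypothetical analogue of the guarded loop above, not MindSpore API.
    def seek_parent_scopes(free_vars):
        """Collect the owning scope of each free variable, skipping
        entries whose scope was already released (None) rather than
        adding None to the result set."""
        parents = set()
        for node in free_vars:
            scope = node.scope  # may be None once the graph is recycled
            if scope is None:
                continue  # log-and-skip, mirroring the MS_LOG(INFO) branch
            parents.add(scope)
        return parents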
@@ -906,7 +916,8 @@ void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
   MS_EXCEPTION_IF_NULL(fg);
   mindspore::HashMap<FuncGraphPtr, FuncGraphSetPtr> seen_fgs;
   fg->seen_ = 1;
-  func_graph_parents_total_analysis_[fg].update(SeekParents(fg, &seen_fgs));
+  auto parents = SeekParents(fg, &seen_fgs);
+  func_graph_parents_total_analysis_[fg].update(parents);
   fg->seen_ = 0;
 }
 
@@ -920,7 +931,6 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
   this->parent_analysis_[fg] = nullptr;
   // Note: must be a copy other than reference as it is modified thereafter.
   auto deps = this->manager_->func_graph_parents_total(fg);
 
   if (deps.empty()) {
-    this->parent_analysis_[fg] = nullptr;
     return;
@@ -187,7 +187,9 @@ class DenseThor(Cell):
             x = self.bias_add(x, self.bias)
         if self.activation_flag:
             x = self.activation(x)
-        return x
+        # We use Depend to make 'self.matrix_g' as primal graph's weight parameter,
+        # for it's used in 'save_gradient' gradient procedure.
+        return F.depend(x, self.matrix_g)
 
     def extend_repr(self):
         s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
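The three thor.py hunks share one pattern: construct now returns F.depend(output, self.matrix_g) instead of the raw output, so 'matrix_g' appears as a used parameter of the primal graph and the grad graph built from it gets a matching parameter slot. A minimal sketch of the pattern in a standalone cell (the cell name, shapes, and parameter values are made up for illustration; only F.depend/ops.Depend is the technique from the commit):

    import numpy as np
    import mindspore as ms
    from mindspore import nn, ops, Parameter, Tensor

    class TinyThorLike(nn.Cell):
        """Toy cell illustrating the Depend pattern from the hunks above."""
        def __init__(self):
            super().__init__()
            self.weight = Parameter(Tensor(np.ones((2, 2), np.float32)), name='weight')
            # Only a custom gradient procedure reads 'matrix_g'; the plain
            # forward computation never touches it.
            self.matrix_g = Parameter(Tensor(np.eye(2, dtype=np.float32)), name='matrix_g')
            self.matmul = ops.MatMul()
            self.depend = ops.Depend()

        def construct(self, x):
            y = self.matmul(x, self.weight)
            # Without this Depend, 'matrix_g' never reaches the forward
            # output, so the grad graph has no parameter slot for it.
            return self.depend(y, self.matrix_g)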
@@ -644,7 +646,9 @@ class EmbeddingThor(Cell):
         output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
 
         output = self.reshape(output_for_reshape, out_shape)
-        return output
+        # We use Depend to make 'self.matrix_g' as primal graph's weight parameter,
+        # for it's used in 'save_gradient' gradient procedure.
+        return F.depend(output, self.matrix_g)
 
     def extend_repr(self):
         s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
@@ -949,4 +953,6 @@ class EmbeddingLookupThor(Cell):
             axis = _make_axis_range(F.rank(indices), F.rank(out))
             clip_by_norm = ClipByNorm(axis)
             out = clip_by_norm(out, self.max_norm)
-        return out
+        # We use Depend to make 'self.matrix_g' as primal graph's weight parameter,
+        # for it's used in 'save_gradient' gradient procedure.
+        return F.depend(out, self.matrix_g)
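Taking gradients with respect to the weight list is where the missing-parameter failure from the commit title surfaced; with the Depend in place, a by-list grad over such a cell resolves every parameter. A usage sketch, reusing the hypothetical TinyThorLike cell defined above:

    # Usage sketch with the hypothetical cell from the earlier example.
    net = TinyThorLike()
    grad_by_list = ops.GradOperation(get_by_list=True)
    params = ms.ParameterTuple(net.trainable_params())
    x = Tensor(np.ones((2, 2), np.float32))
    # Returns one gradient per parameter; 'matrix_g' now has a slot too.
    grads = grad_by_list(net, params)(x)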