forked from mindspore-Ecosystem/mindspore
!49770 Optimize replacing output method
Merge pull request !49770 from ZPaC/optimize-replacing-output-method
commit 070ef4807f
@@ -130,9 +130,23 @@ AnfNodePtr CreateReplacedOutputNode(const FuncGraphPtr &func_graph, const AnfNod
   MS_EXCEPTION_IF_NULL(origin_output);
   MS_EXCEPTION_IF_NULL(origin_output->abstract());
   if (origin_output->abstract()->isa<abstract::AbstractTuple>()) {
+    auto kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(origin_output, kIndex0);
+    auto real_output = kernel_with_index.first;
+    if (!IsPrimitiveCNode(real_output, prim::kPrimMakeTuple)) {
+      MS_LOG(EXCEPTION) << "Tuple output is not a MakeTuple node: " << real_output->DebugString();
+    }
     AnfNodePtrList tuple_inputs;
     auto tuple_elements = origin_output->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
-    for (const auto &element : tuple_elements) {
+    for (size_t i = kIndex0; i < tuple_elements.size(); i++) {
+      // If tuple input is a ValueNode, use it as new tuple's input.
+      const auto tuple_input = real_output->cast<CNodePtr>()->input(i + kSizeOne);
+      if (tuple_input->isa<Parameter>() || tuple_input->isa<ValueNode>()) {
+        MS_LOG(INFO) << "Use " << tuple_input->DebugString() << " as replaced output.";
+        tuple_inputs.emplace_back(tuple_input);
+        continue;
+      }
+
+      const auto &element = tuple_elements[i];
       MS_EXCEPTION_IF_NULL(element);
       auto tensor_abstract = element->cast<abstract::AbstractTensorPtr>();
       if (!tensor_abstract) {
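The added branch reuses nodes that already exist in the graph: the loop now walks the MakeTuple CNode behind the tuple output (offset by one, since a CNode's first input is the primitive itself) and, when an element is already a Parameter or ValueNode, keeps that node as the replaced output instead of synthesizing a new one; only the remaining tensor elements fall through to the original path. The Python sketch below mirrors that control flow with toy stand-ins; Parameter, ValueNode, MakeTuple, and make_fake_output here are hypothetical simplifications, not the MindSpore API.

# Toy model of the replacement strategy in the hunk above (not MindSpore code).
class Parameter:
    """Stand-in for a graph parameter node."""
    def __init__(self, name):
        self.name = name

class ValueNode:
    """Stand-in for a constant value node."""
    def __init__(self, value):
        self.value = value

class MakeTuple:
    """Stand-in for a MakeTuple CNode; inputs[i] produces tuple element i."""
    def __init__(self, *inputs):
        self.inputs = list(inputs)

def make_fake_output(element_abstract):
    # Placeholder for the pre-existing path that builds a fresh node
    # from the element's abstract (shape/dtype) information.
    return ValueNode(("fake", element_abstract))

def replace_tuple_output(make_tuple, tuple_elements):
    """Rebuild the output tuple, reusing Parameter/ValueNode inputs directly."""
    new_inputs = []
    for i, element_abstract in enumerate(tuple_elements):
        # The C++ code uses input(i + kSizeOne) because a CNode's input 0 is the primitive.
        tuple_input = make_tuple.inputs[i]
        if isinstance(tuple_input, (Parameter, ValueNode)):
            new_inputs.append(tuple_input)  # reuse the existing node, no replacement needed
            continue
        new_inputs.append(make_fake_output(element_abstract))
    return MakeTuple(*new_inputs)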
@@ -396,13 +396,12 @@ class Model:
     def _get_metrics(self):
         """Get metrics local values."""
         metrics = dict()
-        # There's no need for server to execute eval, just give fake metrics.
+        # Embedding cache server as a storage service, no need to execute eval, just give fake metrics.
+        is_embedding_cache_server = _is_role_pserver() and _cache_enable()
         for key, value in self._metric_fns.items():
-            if not _is_role_pserver():
+            if not is_embedding_cache_server:
                 metrics[key] = value.eval()
             else:
-                metrics[key] = 0
+                metrics[key] = 1
         return metrics
 
     def _get_scaling_sens(self):
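On the Python side, the effect of the change is that only an embedding cache server (a parameter-server role with the embedding cache enabled) skips metric evaluation, and the placeholder it reports is 1 rather than 0. A minimal, self-contained sketch of that behaviour, using a made-up DummyAccuracy metric in place of a real MindSpore metric:

# Hypothetical stand-alone version of the new _get_metrics logic (not the real Model class).
class DummyAccuracy:
    def eval(self):
        return 0.93  # made-up local metric value

def get_metrics(metric_fns, is_embedding_cache_server):
    metrics = dict()
    for key, value in metric_fns.items():
        if not is_embedding_cache_server:
            metrics[key] = value.eval()
        else:
            # The embedding cache server only stores embeddings; give a fake metric.
            metrics[key] = 1
    return metrics

print(get_metrics({"acc": DummyAccuracy()}, is_embedding_cache_server=False))  # {'acc': 0.93}
print(get_metrics({"acc": DummyAccuracy()}, is_embedding_cache_server=True))   # {'acc': 1}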