!49770 Optimize replacing output method

Merge pull request !49770 from ZPaC/optimize-replacing-output-method
This commit is contained in:
i-robot 2023-03-06 08:31:34 +00:00 committed by Gitee
commit 070ef4807f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 18 additions and 5 deletions

View File

@ -130,9 +130,23 @@ AnfNodePtr CreateReplacedOutputNode(const FuncGraphPtr &func_graph, const AnfNod
MS_EXCEPTION_IF_NULL(origin_output);
MS_EXCEPTION_IF_NULL(origin_output->abstract());
if (origin_output->abstract()->isa<abstract::AbstractTuple>()) {
auto kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(origin_output, kIndex0);
auto real_output = kernel_with_index.first;
if (!IsPrimitiveCNode(real_output, prim::kPrimMakeTuple)) {
MS_LOG(EXCEPTION) << "Tuple output is not a MakeTuple node: " << real_output->DebugString();
}
AnfNodePtrList tuple_inputs;
auto tuple_elements = origin_output->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
for (const auto &element : tuple_elements) {
for (size_t i = kIndex0; i < tuple_elements.size(); i++) {
// If tuple input is a ValueNode, use it as new tuple's input.
const auto tuple_input = real_output->cast<CNodePtr>()->input(i + kSizeOne);
if (tuple_input->isa<Parameter>() || tuple_input->isa<ValueNode>()) {
MS_LOG(INFO) << "Use " << tuple_input->DebugString() << " as replaced output.";
tuple_inputs.emplace_back(tuple_input);
continue;
}
const auto &element = tuple_elements[i];
MS_EXCEPTION_IF_NULL(element);
auto tensor_abstract = element->cast<abstract::AbstractTensorPtr>();
if (!tensor_abstract) {

View File

@ -396,13 +396,12 @@ class Model:
def _get_metrics(self):
"""Get metrics local values."""
metrics = dict()
# Embedding cache server as a storage service, no need to execute eval, just give fake metrics.
is_embedding_cache_server = _is_role_pserver() and _cache_enable()
# There's no need for server to execute eval, just give fake metrics.
for key, value in self._metric_fns.items():
if not is_embedding_cache_server:
if not _is_role_pserver():
metrics[key] = value.eval()
else:
metrics[key] = 0
metrics[key] = 1
return metrics
def _get_scaling_sens(self):