!63201 optimize trace in Cell

Merge pull request !63201 from NaCN/optimize_trace
i-robot 2023-12-19 06:29:15 +00:00 committed by Gitee
commit 370abc06b2
3 changed files with 43 additions and 22 deletions

Changed file 1 of 3: pylint suppression whitelist

@@ -50,6 +50,8 @@
 "mindspore/mindspore/python/mindspore/ops/composite/math_ops.py" "unused-import"
 "mindspore/mindspore/python/mindspore/ops/primitive.py" "assignment-from-none"
 "mindspore/mindspore/python/mindspore/ops/primitive.py" "protected-access"
+"mindspore/mindspore/python/mindspore/ops/_tracefunc.py" "protected-access"
+"mindspore/mindspore/python/mindspore/ops/_tracefunc.py" "function-redefined"
 "mindspore/mindspore/python/mindspore/nn/cell.py" "assignment-from-none"
 "mindspore/mindspore/python/mindspore/_extends/parse/resources.py" "bad-whitespace"
 "mindspore/mindspore/python/mindspore/_extends/parse/parser.py" "broad-except"

Changed file 2 of 3: mindspore/python/mindspore/nn/cell.py

@@ -43,8 +43,6 @@ from mindspore.ops.operations import _inner_ops as inner
 from mindspore.parallel.shard import Shard
 from mindspore._check_jit_forbidden_api import jit_forbidden_register
 from mindspore.common._decorator import deprecated
-from mindspore._c_expression import PackExpander
-from mindspore.ops._tracefunc import _convert_tensor, _SetMixedPrecision, PackFunc
@@ -667,9 +665,6 @@ class Cell(Cell_):
             args = bound_arguments.args
             kwargs = bound_arguments.kwargs
-        if PackFunc.is_tracing():
-            return self._run_tracefunc(*args, **kwargs)
         if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
             self.check_names_and_refresh_name()
             self._is_check_and_refresh = True
@@ -2519,22 +2514,6 @@ class Cell(Cell_):
                     f"The {index + 1}th input of 'set_inputs' or tuple(list) in 'set_inputs' must be the same with "
                     f"network's input, but got set_inputs: {set_input} and network's input: {net_input}.")

-    def _run_tracefunc(self, *args, **kwargs):
-        """ Run Packed Cell in Pack."""
-        args = self._mixed_precision_cast(args)
-        need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
-        if not PackFunc.current.is_pynative_mode and need_subgraph:
-            expander = PackExpander.get_instance()
-            args = expander.begin_subgraph(self, *args)
-            args = [_convert_tensor(a) for a in args]
-            output = self._run_construct(args, kwargs)
-            ret = expander.end_subgraph(self, output)
-            output = _convert_tensor(ret)
-        else:
-            with _SetMixedPrecision(self):
-                output = self._run_construct(args, kwargs)
-        return output
-
     def _mixed_precision_cast(self, inputs):
         mixed_type = self.get_mixed_precision_type()
         if mixed_type == MixedPrecisionType.NOTSET:

Changed file 3 of 3: mindspore/python/mindspore/ops/_tracefunc.py

@@ -18,6 +18,7 @@ import types
 import textwrap
 import inspect
 import os
+from mindspore import nn
 from mindspore.common.tensor import Tensor
 from mindspore.ops.primitive import _RunOpHook, Primitive
 from mindspore._c_expression import PackExpander, PackNode
@@ -60,6 +61,45 @@ def _convert_tensor(node):
         return tuple(_convert_tensor(e) for e in node)
     return node

+class PackFunc:
+    pass
+
+
+def _trace_cell_call(self, *args, **kwargs):
+    """ Run Packed Cell in Pack."""
+    if self.__class__.construct is nn.Cell.construct:
+        raise AttributeError("For 'Cell', the method 'construct' is not defined.")
+    if kwargs:
+        bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
+        bound_arguments.apply_defaults()
+        args = bound_arguments.args
+        kwargs = bound_arguments.kwargs
+    args = self._mixed_precision_cast(args)
+    need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
+    if not PackFunc.current.is_pynative_mode and need_subgraph:
+        expander = PackExpander.get_instance()
+        args = expander.begin_subgraph(self, *args)
+        args = [_convert_tensor(a) for a in args]
+        output = self._run_construct(args, kwargs)
+        ret = expander.end_subgraph(self, output)
+        output = _convert_tensor(ret)
+    else:
+        with _SetMixedPrecision(self):
+            output = self._run_construct(args, kwargs)
+    return output
+
+
+class _PackHook:
+    """Hook for trace run"""
+
+    def __init__(self):
+        self.origin_call = nn.Cell.__call__
+
+    def __enter__(self):
+        nn.Cell.__call__ = _trace_cell_call
+        return self
+
+    def __exit__(self, *err):
+        nn.Cell.__call__ = self.origin_call
+
+
 class PackFunc(Primitive):
     """pack function with lazy expander"""
@@ -114,7 +154,7 @@ class PackFunc(Primitive):
         return _convert_tensor(ret)

     def _run_op(self, args):
-        with _RunOpHook(PackFunc._trace_run_op):
+        with _RunOpHook(PackFunc._trace_run_op), _PackHook():
             fun_args = [_convert_tensor(a) for a in args]
             ret = self.func(*fun_args, **self.kwargs)
             return ret
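
Note on the change: the trace-time call path moves out of Cell.__call__ and into _tracefunc.py. Instead of checking PackFunc.is_tracing() on every Cell call, PackFunc._run_op now enters a _PackHook context that temporarily rebinds nn.Cell.__call__ to _trace_cell_call and restores the original binding on exit, so ordinary (non-trace) calls no longer pay the per-call check. The minimal sketch below illustrates the scoped monkey-patch pattern only; Layer, TraceHook, _traced_call and TRACE_LOG are placeholder names for this illustration, not MindSpore APIs.

# Illustrative sketch of a scoped method-swap hook (placeholder names).
class Layer:
    """Stand-in for nn.Cell: a plain call just runs compute()."""

    def __call__(self, x):
        return self.compute(x)

    def compute(self, x):
        return x * 2


TRACE_LOG = []


def _traced_call(self, x):
    # Stand-in for _trace_cell_call: record the call, then run the layer.
    TRACE_LOG.append(type(self).__name__)
    return self.compute(x)


class TraceHook:
    """Stand-in for _PackHook: swap Layer.__call__ on entry, restore on exit."""

    def __init__(self):
        self.origin_call = Layer.__call__

    def __enter__(self):
        Layer.__call__ = _traced_call
        return self

    def __exit__(self, *err):
        Layer.__call__ = self.origin_call


layer = Layer()
assert layer(3) == 6 and not TRACE_LOG       # normal call: no tracing involved
with TraceHook():
    assert layer(3) == 6                     # traced call: routed through _traced_call
assert TRACE_LOG == ["Layer"]
assert Layer.__call__ is not _traced_call    # original __call__ restored on exit

Patching at the class level works because Python looks up __call__ on the type, so every instance is covered for the duration of the with block, and __exit__ restores the original binding even if tracing raises.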